| @@ -0,0 +1,31 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class DictionaryExtension | |||
| { | |||
| public static void Deconstruct<T1, T2>(this KeyValuePair<T1, T2> pair, out T1 first, out T2 second) | |||
| { | |||
| first = pair.Key; | |||
| second = pair.Value; | |||
| } | |||
| public static void Update<T1, T2>(this Dictionary<T1, T2> dic, IDictionary<T1, T2> other) | |||
| { | |||
| foreach(var (key, value) in other) | |||
| { | |||
| dic[key] = value; | |||
| } | |||
| } | |||
| public static T2 GetOrDefault<T1, T2>(this Dictionary<T1, T2> dic, T1 key, T2 defaultValue) | |||
| { | |||
| if (dic.ContainsKey(key)) | |||
| { | |||
| return dic[key]; | |||
| } | |||
| return defaultValue; | |||
| } | |||
| } | |||
| } | |||
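A minimal usage sketch of these extension methods (illustrative values; assumes `Tensorflow.Common.Extensions` is imported):

```csharp
using System;
using System.Collections.Generic;
using Tensorflow.Common.Extensions;

var defaults = new Dictionary<string, int> { ["CPU"] = 1 };
var overrides = new Dictionary<string, int> { ["GPU"] = 2 };

defaults.Update(overrides);                  // copies/overwrites entries: CPU=1, GPU=2
int tpus = defaults.GetOrDefault("TPU", 0);  // missing key falls back to 0

foreach (var (device, count) in defaults)    // Deconstruct lets the KeyValuePair be destructured
    Console.WriteLine($"{device}: {count}");
```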
| @@ -21,7 +21,7 @@ namespace Tensorflow | |||
| { | |||
| public partial class tensorflow | |||
| { | |||
| GradientTape _tapeSet; | |||
| internal GradientTape _tapeSet; | |||
| /// <summary> | |||
| /// Record operations for automatic differentiation. | |||
| @@ -14,6 +14,8 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class tensorflow | |||
| @@ -79,5 +81,10 @@ namespace Tensorflow | |||
| num_split: num_split, | |||
| axis: axis, | |||
| name: name); | |||
| public Tensor ensure_shape(Tensor x, Shape shape, string name = null) | |||
| { | |||
| return gen_ops.ensure_shape(x, shape, name); | |||
| } | |||
| } | |||
| } | |||
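A small sketch of calling the new `tf.ensure_shape` binding (tensor and shape values are illustrative; assumes `using static Tensorflow.Binding;`):

```csharp
// ensure_shape passes the tensor through while asserting its shape at runtime;
// -1 marks a dimension that may vary.
var x = tf.constant(new float[,] { { 1f, 2f, 3f } });
var checked_x = tf.ensure_shape(x, new Shape(-1, 3));
```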
| @@ -61,7 +61,7 @@ namespace Tensorflow | |||
| public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||
| public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Set `num_dims` to -1 to represent "unknown rank". | |||
| @@ -22,6 +22,7 @@ using System.ComponentModel; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -107,6 +107,12 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public void Release() | |||
| { | |||
| _handle.Dispose(); | |||
| _handle = null; | |||
| } | |||
| public override string ToString() | |||
| => $"0x{_handle.DangerousGetHandle():x16}"; | |||
| @@ -161,7 +161,7 @@ public static class CheckPointUtils | |||
| internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list) | |||
| { | |||
| return full_list.TakeWhile(x => | |||
| return full_list.Where(x => | |||
| { | |||
| var saveables = x.gather_saveables_for_checkpoint(); | |||
| return saveables is not null && saveables.Count > 0; | |||
| @@ -109,6 +109,7 @@ namespace Tensorflow.Checkpoint | |||
| TrackableObjectGraph.Types.TrackableObject trackable_object = new(); | |||
| trackable_object.SlotVariables.AddRange(td.slot_variable_proto); | |||
| trackable_object.Children.AddRange(td.children_proto); | |||
| object_graph_proto.Nodes.Add(trackable_object); | |||
| } | |||
| return object_graph_proto; | |||
| } | |||
| @@ -14,9 +14,11 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Extensions; | |||
| namespace Tensorflow.Contexts | |||
| { | |||
| @@ -25,12 +27,93 @@ namespace Tensorflow.Contexts | |||
| /// </summary> | |||
| public sealed partial class Context | |||
| { | |||
| public ConfigProto Config { get; set; } = new ConfigProto | |||
| protected Device.PhysicalDevice[] _physical_devices; | |||
| protected Dictionary<Device.PhysicalDevice, int> _physical_device_to_index; | |||
| ConfigProto _config; | |||
| public ConfigProto Config | |||
| { | |||
| GpuOptions = new GPUOptions | |||
| get | |||
| { | |||
| _initialize_physical_devices(); | |||
| var config = new ConfigProto(); | |||
| if(_config is not null) | |||
| { | |||
| config.MergeFrom(_config); | |||
| } | |||
| config.LogDevicePlacement = _log_device_placement; | |||
| config.DeviceCount["CPU"] = 0; | |||
| config.DeviceCount["GPU"] = 0; | |||
| foreach(var dev in _physical_devices) | |||
| { | |||
| if (config.DeviceCount.ContainsKey(dev.DeviceType)) | |||
| { | |||
| config.DeviceCount[dev.DeviceType] += 1; | |||
| } | |||
| else | |||
| { | |||
| config.DeviceCount[dev.DeviceType] = 1; | |||
| } | |||
| } | |||
| var gpu_options = _compute_gpu_options(); | |||
| config.GpuOptions = GPUOptions.Parser.ParseFrom(gpu_options.ToByteArray()); | |||
| return config; | |||
| } | |||
| set | |||
| { | |||
| _config = value; | |||
| } | |||
| } | |||
| protected void _initialize_physical_devices(bool reinitialize = false) | |||
| { | |||
| if(!reinitialize && _physical_devices is not null) | |||
| { | |||
| return; | |||
| } | |||
| var devs = list_physical_devices(); | |||
| _physical_devices = devs.Select(d => new Device.PhysicalDevice() | |||
| { | |||
| DeviceName = d.DeviceName, | |||
| DeviceType = d.DeviceType | |||
| }).ToArray(); | |||
| _physical_device_to_index = _physical_devices.Select((p, i) => new KeyValuePair<Device.PhysicalDevice, int>(p, i)) | |||
| .ToDictionary(x => x.Key, x => x.Value); | |||
| _import_config(); | |||
| } | |||
| protected void _import_config() | |||
| { | |||
| if(_config is null) | |||
| { | |||
| return; | |||
| } | |||
| if(!_config.DeviceCount.TryGetValue("CPU", out var num_cpus)) | |||
| { | |||
| num_cpus = 1; | |||
| } | |||
| if(num_cpus != 1) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| } | |||
| }; | |||
| var gpus = _physical_devices.Where(d => d.DeviceType == "GPU"); | |||
| if(gpus.Count() == 0) | |||
| { | |||
| return; | |||
| } | |||
| if(!_config.DeviceCount.TryGetValue("GPU", out var gpu_count)) | |||
| { | |||
| gpu_count = 0; | |||
| } | |||
| // TODO(Rinne): implement it. | |||
| } | |||
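A hypothetical illustration of the new `Config` semantics: the setter stores a base proto, and every read rebuilds a merged copy from that proto plus the discovered physical devices (assumes `using static Tensorflow.Binding;`, values illustrative):

```csharp
var baseConfig = new ConfigProto();
baseConfig.DeviceCount["CPU"] = 1;

tf.Context.Config = baseConfig;   // stored as _config
var merged = tf.Context.Config;   // rebuilt on every get:
// merged.DeviceCount reflects the devices found by list_physical_devices(),
// and merged.GpuOptions is filled from _compute_gpu_options().
```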
| ConfigProto MergeConfig() | |||
| { | |||
| @@ -38,7 +38,26 @@ namespace Tensorflow.Contexts | |||
| public string ScopeName { get; set; } = ""; | |||
| bool initialized = false; | |||
| ContextSwitchStack context_switches; | |||
| public FunctionCallOptions FunctionCallOptions { get; } | |||
| protected FunctionCallOptions _function_call_options; | |||
| public FunctionCallOptions FunctionCallOptions | |||
| { | |||
| get | |||
| { | |||
| if(_function_call_options is null) | |||
| { | |||
| var config = Config; | |||
| _function_call_options = new FunctionCallOptions() | |||
| { | |||
| Config = config | |||
| }; | |||
| } | |||
| return _function_call_options; | |||
| } | |||
| set | |||
| { | |||
| _function_call_options = value; | |||
| } | |||
| } | |||
| SafeContextHandle _handle; | |||
| @@ -62,7 +81,6 @@ namespace Tensorflow.Contexts | |||
| if (initialized) | |||
| return; | |||
| Config = MergeConfig(); | |||
| FunctionCallOptions.Config = Config; | |||
| var config_str = Config.ToByteArray(); | |||
| var opts = new ContextOptions(); | |||
| @@ -167,11 +185,29 @@ namespace Tensorflow.Contexts | |||
| return c_api.TFE_ContextHasFunction(_handle, name); | |||
| } | |||
| public void add_function(SafeFuncGraphHandle fn) | |||
| { | |||
| ensure_initialized(); | |||
| Status status = new(); | |||
| c_api.TFE_ContextAddFunction(_handle, fn, status); | |||
| status.Check(true); | |||
| } | |||
| public void remove_function(string name) | |||
| { | |||
| ensure_initialized(); | |||
| Status status = new(); | |||
| c_api.TFE_ContextRemoveFunction(_handle, name, status); | |||
| status.Check(true); | |||
| } | |||
| public void add_function_def(FunctionDef fdef) | |||
| { | |||
| ensure_initialized(); | |||
| var fdef_string = fdef.ToString(); | |||
| c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, fdef_string.Length); | |||
| var fdef_string = fdef.ToByteArray(); | |||
| Status status = new Status(); | |||
| c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, (ulong)fdef_string.Length, status); | |||
| status.Check(true); | |||
| } | |||
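A hedged sketch of driving the new status-checked registration path; `existing_fn` is a hypothetical `EagerDefinedFunction` used only as a source of a `FunctionDef`:

```csharp
// Register a FunctionDef with the eager context; TFE_ContextAddFunctionDef now
// receives the serialized proto bytes plus a Status that is checked afterwards.
FunctionDef fdef = existing_fn.Definition;   // hypothetical source of the proto
tf.Context.add_function_def(fdef);

// The registration can later be removed by the function's signature name.
tf.Context.remove_function(fdef.Signature.Name);
```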
| public void restore_mode() | |||
| @@ -9,10 +9,11 @@ namespace Tensorflow.Contexts | |||
| public class FunctionCallOptions | |||
| { | |||
| public ConfigProto Config { get; set; } | |||
| public string ExecutorType { get; set; } | |||
| public string config_proto_serialized() | |||
| public ByteString config_proto_serialized() | |||
| { | |||
| return Config.ToByteString().ToStringUtf8(); | |||
| return Config.ToByteString(); | |||
| } | |||
| } | |||
| } | |||
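With the return type changed from a UTF-8 string to the raw `ByteString`, callers adapt roughly as follows (a sketch; assumes `using Google.Protobuf;`):

```csharp
var opts = tf.Context.FunctionCallOptions;
ByteString serialized = opts.config_proto_serialized();

byte[] bytes = serialized.ToByteArray();   // raw proto bytes, no UTF-8 round trip
bool isEmpty = serialized.Length == 0;     // the emptiness test now used in EagerDefinedFunction.Call
```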
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Functions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Eager | |||
| @@ -0,0 +1,53 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| internal static class backprop_util | |||
| { | |||
| // TODO: add quantized_dtypes (once they are supported). | |||
| private static HashSet<TF_DataType> _trainable_dtypes = new HashSet<TF_DataType>(new TF_DataType[] | |||
| { | |||
| dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128, | |||
| dtypes.resource, dtypes.variant, TF_DataType.TF_BFLOAT16 | |||
| }); | |||
| public static bool IsTrainable(Tensor tensor) | |||
| { | |||
| var dtype = _DTypeFromTensor(tensor); | |||
| return _trainable_dtypes.Contains(dtype); | |||
| } | |||
| public static bool IsTrainable(TF_DataType dtype) | |||
| { | |||
| return _trainable_dtypes.Contains(dtype); | |||
| } | |||
| private static TF_DataType _DTypeFromTensor(Tensor tensor) | |||
| { | |||
| var dtype = tensor.dtype; | |||
| if(dtype.as_base_dtype() == TF_DataType.TF_VARIANT) | |||
| { | |||
| CppShapeInferenceResult.Types.HandleData handle_data; | |||
| if (tensor is EagerTensor) | |||
| { | |||
| handle_data = tensor.HandleData; | |||
| } | |||
| else | |||
| { | |||
| handle_data = handle_data_util.get_resource_handle_data(tensor); | |||
| } | |||
| if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null && | |||
| handle_data.ShapeAndType.Count > 0) | |||
| { | |||
| var first_type = handle_data.ShapeAndType[0].Dtype; | |||
| if(first_type != DataType.DtInvalid && handle_data.ShapeAndType.All(x => x.Dtype == first_type)) | |||
| { | |||
| return first_type.as_tf_dtype(); | |||
| } | |||
| } | |||
| } | |||
| return dtype; | |||
| } | |||
| } | |||
| } | |||
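Inside the library (the class is `internal`), the trainability check behaves roughly as follows; the tensors here are illustrative:

```csharp
var weights = tf.constant(new float[] { 1f, 2f });   // float32 is in _trainable_dtypes
var indices = tf.constant(new int[] { 0, 1 });       // int32 is not

bool recordsGrad = backprop_util.IsTrainable(weights);   // true
bool skipsGrad = backprop_util.IsTrainable(indices);     // false
```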
| @@ -31,7 +31,7 @@ namespace Tensorflow | |||
| public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, string serialized_function_def, int size); | |||
| public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, byte[] serialized_function_def, ulong size, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); | |||
| @@ -280,7 +280,7 @@ namespace Tensorflow | |||
| public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetAttrValueProto(SafeEagerOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | |||
| public static extern void TFE_OpSetAttrValueProto(IntPtr op, string attr_name, IntPtr proto, ulong proto_len, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// | |||
| @@ -1,6 +0,0 @@ | |||
| namespace Tensorflow.Framework.Models | |||
| { | |||
| class ScopedTFFunction | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,22 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Framework | |||
| { | |||
| internal class ScopedTFFunction | |||
| { | |||
| SafeFuncGraphHandle _handle; | |||
| string _name; | |||
| public ScopedTFFunction(SafeFuncGraphHandle func, string name) | |||
| { | |||
| _handle = func; | |||
| _name = name; | |||
| } | |||
| public SafeFuncGraphHandle Get() | |||
| { | |||
| return _handle; | |||
| } | |||
| } | |||
| } | |||
| @@ -4,6 +4,7 @@ using System.Diagnostics; | |||
| using System.Security.Cryptography; | |||
| using System.Text; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Common.Extensions; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| @@ -64,7 +65,7 @@ namespace Tensorflow.Framework | |||
| { | |||
| output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name; | |||
| } | |||
| // TODO(Rinne): func_graph._output_names = output_names | |||
| func_graph._output_names = output_names; | |||
| func_graph.Exit(); | |||
| return func_graph; | |||
| @@ -154,9 +155,17 @@ namespace Tensorflow.Framework | |||
| foreach(var node_def in fdef.NodeDef) | |||
| { | |||
| var graph = default_graph; | |||
| // TODO(Rinne): The `Graph` lacks `_functions`, needed to be implemented in the future. | |||
| while(graph.OuterGraph is not null) | |||
| while (true) | |||
| { | |||
| if(graph is null) | |||
| { | |||
| break; | |||
| } | |||
| var f = graph.Functions.GetOrDefault(node_def.Op, null); | |||
| if(f is not null && graph.OuterGraph is null) | |||
| { | |||
| break; | |||
| } | |||
| graph = graph.OuterGraph; | |||
| } | |||
| @@ -4,6 +4,7 @@ using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Util; | |||
| @@ -19,7 +20,7 @@ namespace Tensorflow.Functions | |||
| protected IEnumerable<Tensor> _captured_inputs; | |||
| internal FuncGraph func_graph; | |||
| protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; | |||
| protected Dictionary<string, string> _attrs; | |||
| protected Dictionary<string, AttrValue> _attrs; | |||
| protected FunctionSpec _function_spec; | |||
| protected FunctionSpec _pre_initialized_function_spec = null; | |||
| protected EagerDefinedFunction _inference_function; | |||
| @@ -29,22 +30,25 @@ namespace Tensorflow.Functions | |||
| public string Name => _delayed_rewrite_functions.Forward().Name; | |||
| public Tensor[] Outputs; | |||
| public Tensor[] Outputs => func_graph.Outputs; | |||
| public Type ReturnType; | |||
| public TensorSpec[] OutputStructure; | |||
| public IEnumerable<string> ArgKeywords { get; set; } | |||
| public long NumPositionArgs { get; set; } | |||
| public FunctionDef FunctionDef => _delayed_rewrite_functions.Forward().Definition; | |||
| public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | |||
| public IEnumerable<IVariableV1> Variables => func_graph.Variables; | |||
| public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | |||
| public ConcreteFunction(string name) | |||
| { | |||
| func_graph = new FuncGraph(name); | |||
| _captured_inputs = func_graph.external_captures; | |||
| _attrs= new Dictionary<string, string>(); | |||
| _attrs= new Dictionary<string, AttrValue>(); | |||
| _set_infer_function(); | |||
| } | |||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | |||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, AttrValue> attrs = null) | |||
| { | |||
| func_graph = graph; | |||
| _captured_inputs = func_graph.external_captures; | |||
| @@ -70,7 +74,7 @@ namespace Tensorflow.Functions | |||
| null); | |||
| func_graph.Exit(); | |||
| _captured_inputs = func_graph.external_captures; | |||
| _attrs = new Dictionary<string, string>(); | |||
| _attrs = new Dictionary<string, AttrValue>(); | |||
| _set_infer_function(); | |||
| } | |||
| @@ -93,7 +97,7 @@ namespace Tensorflow.Functions | |||
| null); | |||
| func_graph.Exit(); | |||
| _captured_inputs = func_graph.external_captures; | |||
| _attrs = new Dictionary<string, string>(); | |||
| _attrs = new Dictionary<string, AttrValue>(); | |||
| _set_infer_function(); | |||
| } | |||
| @@ -160,27 +164,20 @@ namespace Tensorflow.Functions | |||
| } | |||
| if (!executing_eagerly) | |||
| { | |||
| // TODO(Rinne): add the check | |||
| } | |||
| tensor_inputs.AddRange(captured_inputs); | |||
| tensor_inputs.AddRange(captured_inputs); | |||
| args = tensor_inputs.ToArray(); | |||
| var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0; | |||
| var possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args); | |||
| // No tape is watching; skip to running the function. | |||
| if (possible_gradient_type == 0 && executing_eagerly) | |||
| if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE && executing_eagerly) | |||
| { | |||
| return _build_call_outputs(_inference_function.Call(args)); | |||
| //var attrs = new object[] | |||
| //{ | |||
| // "executor_type", "", | |||
| // "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
| //}; | |||
| //return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||
| } | |||
| if (forward_backward == null) | |||
| forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||
| forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||
| var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||
| Tensors flat_outputs = null; | |||
| if (executing_eagerly) | |||
| @@ -189,8 +186,12 @@ namespace Tensorflow.Functions | |||
| } | |||
| else | |||
| { | |||
| // TODO(Rinne): add `default_graph._override_gradient_function`. | |||
| flat_outputs = forward_function.Call(args_with_tangents); | |||
| tf_with(default_graph._override_gradient_function(new Dictionary<string, Func<Operation, object[], Tensor[]>>(){ | |||
| { "PartitionedCall", _get_gradient_function() }, { "StatefulPartitionedCall", _get_gradient_function() } | |||
| }), _ => | |||
| { | |||
| flat_outputs = forward_function.Call(args_with_tangents); | |||
| }); | |||
| } | |||
| forward_backward.Record(flat_outputs); | |||
| return _build_call_outputs(flat_outputs); | |||
| @@ -215,7 +216,8 @@ namespace Tensorflow.Functions | |||
| TangentInfo input_tangents; | |||
| if (executing_eagerly) | |||
| { | |||
| throw new NotImplementedException(); | |||
| // TODO(Rinne): check if it needs to be implemented. | |||
| input_tangents = new TangentInfo(); | |||
| } | |||
| else | |||
| { | |||
| @@ -239,7 +241,12 @@ namespace Tensorflow.Functions | |||
| } | |||
| // TODO(Rinne): add arg "input_tangents" for ForwardBackwardCall. | |||
| return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); | |||
| return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: tf.Runner.MustRecordGradient()); | |||
| } | |||
| internal void set_variables(IEnumerable<IVariableV1> variables) | |||
| { | |||
| func_graph.Variables = variables; | |||
| } | |||
| internal void _set_infer_function() | |||
| @@ -274,6 +281,11 @@ namespace Tensorflow.Functions | |||
| }; | |||
| } | |||
| internal Func<Operation, object[], Tensor[]> _get_gradient_function() | |||
| { | |||
| return _delayed_rewrite_functions._rewrite_forward_and_call_backward; | |||
| } | |||
| private Tensors _build_call_outputs(Tensors result) | |||
| { | |||
| // TODO(Rinne): deal with `func_graph.structured_outputs` | |||
| @@ -9,18 +9,27 @@ using Tensorflow.Eager; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Util; | |||
| using Tensorflow.Common.Extensions; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Framework; | |||
| using System.Buffers; | |||
| using Tensorflow.Gradients; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| public class EagerDefinedFunction | |||
| public class EagerDefinedFunction: IDisposable | |||
| { | |||
| public int _num_outputs; | |||
| FuncGraph _func_graph; | |||
| FuncGraph _graph; | |||
| FunctionDef _definition; | |||
| OpDef _signature; | |||
| string _name; | |||
| Tensor[] _func_graph_outputs; | |||
| internal ScopedTFFunction _c_func; | |||
| internal Tensor[] _func_graph_outputs; | |||
| internal string _grad_func_name; | |||
| internal Func<Operation, Tensor[], Tensor[]> csharp_grad_func; | |||
| internal EagerDefinedFunction _grad_func; | |||
| internal bool _registered_on_context = false; | |||
| public string Name => _name; | |||
| public DataType[] OutputTypes { get; protected set; } | |||
| public Shape[] OutputShapes { get; protected set; } | |||
| @@ -47,48 +56,93 @@ namespace Tensorflow.Functions | |||
| return _signature; | |||
| } | |||
| } | |||
| public EagerDefinedFunction(string name, FuncGraph graph, | |||
| public unsafe EagerDefinedFunction(string name, FuncGraph graph, | |||
| Tensors inputs, Tensors outputs, | |||
| Dictionary<string, string> attrs) | |||
| Dictionary<string, AttrValue> attrs) | |||
| { | |||
| var input_ops = inputs.Select(x => x.op).ToArray(); | |||
| var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) | |||
| .Select(x => x as Operation).ToArray(); | |||
| var output_names = new string[0]; | |||
| _func_graph = new FuncGraph(graph, name, attrs); | |||
| _func_graph_outputs = new List<Tensor>(outputs).ToArray(); | |||
| _func_graph.ToGraph(operations, inputs, outputs, output_names); | |||
| var graph_output_names = graph._output_names; | |||
| string[] output_names; | |||
| if(graph_output_names is not null && outputs.All(t => graph_output_names.ContainsKey(ops.tensor_id(t)))) | |||
| { | |||
| output_names = outputs.Select(t => graph_output_names[ops.tensor_id(t)]).ToArray(); | |||
| if(output_names.Distinct().Count() != output_names.Length) | |||
| { | |||
| output_names = new string[0]; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| output_names = new string[0]; | |||
| } | |||
| Status status = new Status(); | |||
| var fn = c_api.TF_GraphToFunction(graph.c_graph, | |||
| name, | |||
| false, | |||
| operations.Length, | |||
| operations.Length == 0 ? new IntPtr[0] : operations.Select(x => (IntPtr)x).ToArray(), | |||
| inputs.Length, | |||
| inputs.Select(t => t._as_tf_output()).ToArray(), | |||
| outputs.Length, | |||
| outputs.Select(t => t._as_tf_output()).ToArray(), | |||
| output_names.Length != outputs.Length ? null : output_names, | |||
| IntPtr.Zero, // warning: the control outputs have been totally ignored. | |||
| null, | |||
| status); | |||
| status.Check(true); | |||
| _c_func = new ScopedTFFunction(fn, name); | |||
| foreach(var (attr_name, attr_value) in attrs) | |||
| { | |||
| var serialized = attr_value.ToByteArray(); | |||
| c_api.TF_FunctionSetAttrValueProto(fn, attr_name, serialized, serialized.Length, status); | |||
| status.Check(true); | |||
| } | |||
| var signature = _get_definition().Signature; | |||
| _name = signature.Name; | |||
| // TODO(Rinne): deal with `fn` | |||
| tf_with(ops.init_scope(), s => | |||
| { | |||
| tf.Context.add_function(fn); | |||
| _registered_on_context = true; | |||
| }); | |||
| _num_outputs = signature.OutputArg.Count; | |||
| OutputTypes = signature.OutputArg.Select(x => x.Type).ToArray(); | |||
| OutputShapes = outputs.Select(x => x.shape).ToArray(); | |||
| _func_graph_outputs = new List<Tensor>(outputs).ToArray(); | |||
| csharp_grad_func = null; | |||
| _graph = graph; | |||
| } | |||
| public Tensors Call(Tensors args) | |||
| public unsafe Tensors Call(Tensors args) | |||
| { | |||
| // TODO(Rinne): Add arg `CancellationManager`. | |||
| // TODO(Rinne): Check the arg length. | |||
| var function_call_options = tf.Context.FunctionCallOptions; | |||
| string config; | |||
| if (string.IsNullOrEmpty(function_call_options.config_proto_serialized())) | |||
| if (function_call_options.config_proto_serialized().Length == 0) | |||
| { | |||
| config = function_utils.get_disabled_rewriter_config(); | |||
| config = function_utils.get_disabled_rewriter_config().ToString(); | |||
| } | |||
| else | |||
| { | |||
| config = function_call_options.config_proto_serialized(); | |||
| config = function_call_options.config_proto_serialized().ToString(); | |||
| } | |||
| // TODO(Rinne): executor_type | |||
| config = ""; // TODO(Rinne): revise it. | |||
| string executor_type = function_call_options.ExecutorType ?? ""; | |||
| var executing_eagerly = tf.Context.executing_eagerly(); | |||
| var attrs = new object[] | |||
| { | |||
| "executor_type", "", | |||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
| "executor_type", executor_type, | |||
| "config_proto", config | |||
| }; | |||
| Tensor[] outputs; | |||
| @@ -103,9 +157,19 @@ namespace Tensorflow.Functions | |||
| } | |||
| else | |||
| { | |||
| tf.GradientTape().stop_recording(); | |||
| outputs = functional_ops.partitioned_call(args, this, OutputTypes, | |||
| executing_eagerly, config, ""); | |||
| if(tf.GetTapeSet().Count == 0) | |||
| { | |||
| outputs = functional_ops.partitioned_call(args, this, OutputTypes, | |||
| executing_eagerly, config, ""); | |||
| } | |||
| else | |||
| { | |||
| var tape = tf.GetTapeSet().Peek(); | |||
| tape.StopRecord(); | |||
| outputs = functional_ops.partitioned_call(args, this, OutputTypes, | |||
| executing_eagerly, config, ""); | |||
| tape.StartRecord(); | |||
| } | |||
| } | |||
| foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs)) | |||
| { | |||
| @@ -141,7 +205,7 @@ namespace Tensorflow.Functions | |||
| { | |||
| g.AddFunction(this); | |||
| } | |||
| foreach(var f in _func_graph.Functions.Values) | |||
| foreach(var f in _graph.Functions.Values) | |||
| { | |||
| if (!g.IsFunction(f.Name)) | |||
| { | |||
| @@ -155,12 +219,15 @@ namespace Tensorflow.Functions | |||
| { | |||
| var buffer = c_api_util.tf_buffer(); | |||
| Status status = new(); | |||
| c_api.TF_FunctionToFunctionDef(_func_graph._func_graph_handle, buffer, status); | |||
| c_api.TF_FunctionToFunctionDef(_c_func.Get(), buffer, status); | |||
| status.Check(true); | |||
| var proto_data = c_api.TF_GetBuffer(buffer); | |||
| FunctionDef function_def = new(); | |||
| function_def.MergeFrom(proto_data.AsSpan<byte>()); | |||
| return function_def; | |||
| return FunctionDef.Parser.ParseFrom(proto_data.AsSpan<byte>()); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| tf.Context.remove_function(Name); | |||
| } | |||
| } | |||
| } | |||
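Since `EagerDefinedFunction` is now `IDisposable` and `Dispose` calls `tf.Context.remove_function(Name)`, a registration can be scoped. A rough sketch, where `graph`, `inputs` and `outputs` stand for an existing `FuncGraph` and its tensors:

```csharp
using (var fn = new EagerDefinedFunction("my_fn", graph, inputs, outputs,
                                         new Dictionary<string, AttrValue>()))
{
    var results = fn.Call(inputs);
}
// Leaving the using block removes the function from the eager context by name.
```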
| @@ -17,9 +17,9 @@ namespace Tensorflow.Functions | |||
| public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||
| { | |||
| var outputs = _func_graph.Outputs; | |||
| (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) | |||
| (_forward_function, _forward_graph, _backward_function, _forwardprop_output_indices, _num_forwardprop_outputs) | |||
| = BuildFunctionsForOutputs(outputs, inference_args); | |||
| return _forward; | |||
| return _forward_function; | |||
| } | |||
| } | |||
| } | |||
| @@ -10,23 +10,26 @@ namespace Tensorflow | |||
| private IntPtr _handle; | |||
| #pragma warning restore CS0169 // The field 'Function._handle' is never used | |||
| protected Func<Tensors, Tensors> _function; | |||
| protected Func<Tensor[], Tensor[]> _csharp_function; | |||
| protected ConcreteFunction _concrete_variable_creation_fn; | |||
| protected bool _auto_graph; | |||
| protected bool _autograph; | |||
| protected TracingCompiler _variable_creation_fn; | |||
| protected bool _has_initialized; | |||
| public string Name { get; set; } | |||
| public Function(Func<Tensors, Tensors> function, | |||
| public Function(Func<Tensor[], Tensor[]> csharp_function, | |||
| string name, bool auto_graph = true) | |||
| { | |||
| _function = function; | |||
| _csharp_function = csharp_function; | |||
| Name = name; | |||
| _auto_graph = auto_graph; | |||
| _autograph = auto_graph; | |||
| _has_initialized = false; | |||
| } | |||
| public virtual Tensors Apply(Tensors inputs) | |||
| { | |||
| if (_run_functions_eagerly()) | |||
| { | |||
| return _function(inputs); | |||
| return _csharp_function(inputs); | |||
| } | |||
| var result = _call(inputs); | |||
| @@ -35,20 +38,32 @@ namespace Tensorflow | |||
| protected virtual Tensors _call(Tensors inputs) | |||
| { | |||
| _initialize(); | |||
| if (!_has_initialized) | |||
| { | |||
| _initialize(inputs); | |||
| } | |||
| return _concrete_variable_creation_fn.CallFlat(inputs, | |||
| _concrete_variable_creation_fn.CapturedInputs); | |||
| } | |||
| protected TracingCompiler _compiler(Func<Tensor[], Tensor[]> fn) | |||
| { | |||
| var name = nameof(fn); | |||
| return new TracingCompiler(fn, name, autograph: _autograph); | |||
| } | |||
| protected virtual bool _run_functions_eagerly() | |||
| { | |||
| return false; | |||
| } | |||
| private void _initialize() | |||
| private void _initialize(Tensor[] args) | |||
| { | |||
| _variable_creation_fn = _compiler(_csharp_function); | |||
| _variable_creation_fn._name = this.Name; | |||
| _concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); | |||
| _has_initialized = true; | |||
| } | |||
| } | |||
| } | |||
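A sketch of the reworked call path: the first `Apply` runs `_initialize`, which traces the delegate through `TracingCompiler` into a `ConcreteFunction`; later calls reuse it. The `Square` delegate and input values are illustrative:

```csharp
Tensor[] Square(Tensor[] xs) => new[] { tf.multiply(xs[0], xs[0]) };

var fn = new Function(Square, name: "square");
var y1 = fn.Apply(new[] { tf.constant(3.0f) });   // traces and builds the ConcreteFunction
var y2 = fn.Apply(new[] { tf.constant(4.0f) });   // _has_initialized is true, reuses it
```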
| @@ -3,8 +3,10 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Operations; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.tensorflow; | |||
| @@ -22,11 +24,11 @@ namespace Tensorflow.Functions | |||
| protected string _INFERENCE_PREFIX = "__inference_"; | |||
| protected FuncGraph _func_graph; | |||
| protected EagerDefinedFunction _forward; | |||
| protected EagerDefinedFunction _forward_function; | |||
| protected FuncGraph _forward_graph; | |||
| protected List<int> _forwardprop_output_indices; | |||
| protected int _num_forwardprop_outputs; | |||
| protected ConcreteFunction _backward; | |||
| protected ConcreteFunction _backward_function; | |||
| BackwardFunction _backward_function_wrapper; | |||
| public TapeGradientFunctions(FuncGraph func_graph, | |||
| @@ -49,8 +51,8 @@ namespace Tensorflow.Functions | |||
| public virtual void Record(Tensors flat_outputs, Tensors inference_args) | |||
| { | |||
| // TODO(Rinne): add arg `input_tangents`. | |||
| var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | |||
| tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, | |||
| var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward_function, flat_outputs); | |||
| tf.Runner.RecordGradient(_forward_function.Name, inference_args, new object[0], to_record, | |||
| getBackwardFunction: backward_function); | |||
| } | |||
| @@ -134,46 +136,58 @@ namespace Tensorflow.Functions | |||
| var trainable_indices = new List<int>(); | |||
| foreach(var (index, output) in enumerate(outputs)) | |||
| { | |||
| if (gradients_util.IsTrainable(output)) | |||
| if (backprop_util.IsTrainable(output)) | |||
| { | |||
| trainable_outputs.Add(output); | |||
| trainable_indices.Add(index); | |||
| } | |||
| } | |||
| var gradients_wrt_outputs = new List<Tensor>(); | |||
| var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"); | |||
| var backwards_graph = new FuncGraph(_func_graph.Name); | |||
| backwards_graph.as_default(); | |||
| var gradients_wrt_outputs = new List<Tensor>(); | |||
| foreach (var output in trainable_outputs) | |||
| gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | |||
| { | |||
| var (gradient_shape, gradient_dtype) = default_gradient.shape_and_dtype(output); | |||
| var gradient_placeholder = tf.placeholder(gradient_dtype, gradient_shape); | |||
| gradients_wrt_outputs.Add(gradient_placeholder); | |||
| handle_data_util.copy_handle_data(output, gradient_placeholder); | |||
| } | |||
| var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | |||
| _func_graph.Inputs, | |||
| grad_ys: gradients_wrt_outputs.ToArray(), | |||
| src_graph: _func_graph); | |||
| _func_graph.Inputs, | |||
| grad_ys: gradients_wrt_outputs.ToArray(), | |||
| src_graph: _func_graph); | |||
| var captures_from_forward = backwards_graph.external_captures | |||
| .Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph) | |||
| .ToArray(); | |||
| HashSet<Tensor> existing_outputs = new(_func_graph.Outputs); | |||
| foreach(var capture in captures_from_forward) | |||
| { | |||
| if (!_func_graph.Outputs.Contains(capture)) | |||
| if (!existing_outputs.Contains(capture)) | |||
| { | |||
| existing_outputs.Add(capture); | |||
| _func_graph.Outputs.Add(capture); | |||
| } | |||
| } | |||
| backwards_graph.Exit(); | |||
| var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||
| var backward_function_attr = new Dictionary<string, string>(); | |||
| backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
| gradients_wrt_outputs.append(backwards_graph.internal_captures); | |||
| backwards_graph.Inputs = gradients_wrt_outputs; | |||
| backwards_graph.Outputs = gradients_wrt_inputs; | |||
| backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray(); | |||
| backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null)); | |||
| var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); | |||
| //var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||
| //var backward_function_attr = new Dictionary<string, string>(); | |||
| //backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
| var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||
| //var backward_function = new ConcreteFunction(backwards_graph, | |||
| // monomorphic_function_utils._parse_func_attrs(backward_function_attr)); | |||
| var forward_function_attr = new Dictionary<string, string>(); | |||
| forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||
| var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||
| _func_graph.Inputs, _func_graph.Outputs, forward_function_attr); | |||
| //var forward_function_attr = new Dictionary<string, string>(); | |||
| //forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||
| //var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||
| // _func_graph.Inputs, _func_graph.Outputs, | |||
| // monomorphic_function_utils._parse_func_attrs(forward_function_attr)); | |||
| return (forward_function, _func_graph, backward_function, null, 0); | |||
| } | |||
| @@ -0,0 +1,84 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Security.Cryptography.X509Certificates; | |||
| using System.Text; | |||
| using Tensorflow.Graphs; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| public class TracingCompiler | |||
| { | |||
| Func<Tensor[], Tensor[]> _csharp_function; | |||
| //FunctionSpec _function_spec; | |||
| internal string _name; | |||
| bool _autograph; | |||
| Dictionary<string, ConcreteFunction> _function_cache; | |||
| Dictionary<string, AttrValue> _function_attributes; | |||
| int _tracing_count; | |||
| public TracingCompiler(Func<Tensor[], Tensor[]> csharp_function, string name, object? input_signatures = null, | |||
| Dictionary<string, AttrValue> attributes = null, bool autograph = true, object? autograph_options = null, | |||
| bool reduce_retracing = false, bool capture_by_value = false) | |||
| { | |||
| _csharp_function = csharp_function; | |||
| bool pure_function = attributes is not null && attributes.Count > 0 && attributes.ContainsKey(monomorphic_function_utils.IMPLEMENTS_ATTRIBUTE_NAME); | |||
| _name = name; | |||
| _autograph = autograph; | |||
| _function_attributes = attributes ?? new Dictionary<string, AttrValue>(); | |||
| _function_cache = new Dictionary<string, ConcreteFunction>(); | |||
| _tracing_count = 0; | |||
| } | |||
| public Tensor[] Apply(Tensor[] inputs) | |||
| { | |||
| // TODO(Rinne): add lock here. | |||
| var (concrete_function, filtered_flat_args) = _maybe_define_function(inputs); | |||
| return concrete_function.CallFlat(filtered_flat_args, concrete_function.CapturedInputs); | |||
| } | |||
| internal ConcreteFunction _get_concrete_function_internal_garbage_collected(Tensor[] args) | |||
| { | |||
| var (concrete_function, _) = _maybe_define_concrete_function(args); | |||
| return concrete_function; | |||
| } | |||
| private (ConcreteFunction, Tensor[]) _maybe_define_concrete_function(Tensor[] args) | |||
| { | |||
| return _maybe_define_function(args); | |||
| } | |||
| private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args) | |||
| { | |||
| var lookup_func_key = make_cache_key(args); | |||
| if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function)) | |||
| { | |||
| return (concrete_function, args); | |||
| } | |||
| concrete_function = _create_concrete_function(args); | |||
| _function_cache[lookup_func_key] = concrete_function; | |||
| return (concrete_function, args); | |||
| } | |||
| private ConcreteFunction _create_concrete_function(Tensor[] args) | |||
| { | |||
| _tracing_count++; | |||
| int arglen = args.Length; | |||
| var concrete_function = new ConcreteFunction(FuncGraph.func_graph_from_func( | |||
| _name, x => _csharp_function(x.Where(y => y is Tensor).Select(y => (Tensor)y).ToArray()), | |||
| args, new Dictionary<string, object>(), autograph: _autograph | |||
| ), _function_attributes); | |||
| return concrete_function; | |||
| } | |||
| private static string make_cache_key(Tensor[] inputs) | |||
| { | |||
| string res = ""; | |||
| foreach (var input in inputs) | |||
| { | |||
| res += $"{input.name}_{input.Id}"; | |||
| } | |||
| return res; | |||
| } | |||
| } | |||
| } | |||
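A hedged sketch of the cache behavior: the key produced by `make_cache_key` concatenates each input's name and id, so feeding the same tensor instances hits the cache while new tensors force a retrace (`Square` as in the `Function` sketch above):

```csharp
var compiler = new TracingCompiler(Square, "square");

var a = tf.constant(2.0f);
var y1 = compiler.Apply(new[] { a });                  // traces once, stores the ConcreteFunction
var y2 = compiler.Apply(new[] { a });                  // identical key → cache hit, no retrace
var y3 = compiler.Apply(new[] { tf.constant(2.0f) });  // new tensor id → new key → retrace
```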
| @@ -16,6 +16,7 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| using Tensorflow.Functions; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -54,6 +55,9 @@ namespace Tensorflow | |||
| public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status); | |||
| public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, SafeFuncGraphHandle grad, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_GraphGetFunctions(SafeGraphHandle g, IntPtr[] funcs, int max_func, SafeStatusHandle status); | |||
| } | |||
| } | |||
| @@ -0,0 +1,50 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| internal static class composite_tensor_utils | |||
| { | |||
| public static List<object> flatten_with_variables(object inputs) | |||
| { | |||
| List<object> flat_inputs = new(); | |||
| foreach(var value in nest.flatten(inputs)) | |||
| { | |||
| if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) | |||
| { | |||
| throw new NotImplementedException("The composite tensor has not been fully supported."); | |||
| } | |||
| else | |||
| { | |||
| flat_inputs.Add(value); | |||
| } | |||
| } | |||
| return flat_inputs; | |||
| } | |||
| public static List<object> flatten_with_variables_or_variable_specs(object arg) | |||
| { | |||
| List<object> flat_inputs = new(); | |||
| foreach(var value in nest.flatten(arg)) | |||
| { | |||
| if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) | |||
| { | |||
| throw new NotImplementedException("The composite tensor has not been fully supported."); | |||
| } | |||
| // TODO(Rinne): deal with `VariableSpec`. | |||
| else if(value is TypeSpec type_spec && value is not TensorSpec) | |||
| { | |||
| throw new NotImplementedException("The TypeSpec has not been fully supported."); | |||
| } | |||
| else | |||
| { | |||
| flat_inputs.Add(value); | |||
| } | |||
| } | |||
| return flat_inputs; | |||
| } | |||
| } | |||
| } | |||
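A rough illustration of the flattening helpers (library-internal): plain tensors and resource variables pass through, while any other composite tensor currently throws:

```csharp
object nestedArgs = new object[]
{
    tf.constant(1.0f),
    new[] { tf.constant(2.0f), tf.constant(3.0f) }
};
List<object> flat = composite_tensor_utils.flatten_with_variables(nestedArgs);
// flat contains the three tensors in traversal order.
```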
| @@ -34,11 +34,10 @@ namespace Tensorflow.Functions | |||
| "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| }); | |||
| var bound_variables = inputs.TakeWhile(obj => obj is IVariableV1); | |||
| var bound_variables = inputs.Where(obj => obj is IVariableV1).Select(x => (IVariableV1)x); | |||
| List<Tensor> captured_inputs_list = new(); | |||
| // TODO(Rinne): concrete_function.set_variables(bound_variables) | |||
| concrete_function.set_variables(bound_variables); | |||
| if (bound_inputs is not null) | |||
| { | |||
| @@ -54,8 +53,14 @@ namespace Tensorflow.Functions | |||
| concrete_function.func_graph.replace_capture(bound_input, internal_capture); | |||
| if(internal_capture.dtype == dtypes.resource) | |||
| { | |||
| // skip the check of variable. | |||
| handle_data_util.copy_handle_data(bound_input, internal_capture); | |||
| if (resource_variable_ops.is_resource_variable(bound_input)) | |||
| { | |||
| handle_data_util.copy_handle_data(bound_input.Handle, internal_capture); | |||
| } | |||
| else | |||
| { | |||
| handle_data_util.copy_handle_data(bound_input, internal_capture); | |||
| } | |||
| } | |||
| concrete_function.func_graph.capture(bound_input); | |||
| } | |||
| @@ -1,20 +1,137 @@ | |||
| using System; | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Common.Extensions; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Framework; | |||
| using static Tensorflow.Binding; | |||
| using System.Diagnostics; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| public class DelayedRewriteGradientFunctions: TapeGradientFunctions | |||
| internal static class monomorphic_function_utils | |||
| { | |||
| internal static string _FORWARD_PREFIX = "__forward_"; | |||
| internal static string _BACKWARD_PREFIX = "__backward_"; | |||
| internal static string _INFERENCE_PREFIX = "__inference_"; | |||
| internal static string IMPLEMENTS_ATTRIBUTE_NAME = "_implements"; | |||
| internal static string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||
| internal static string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||
| public static string _inference_name(string name) | |||
| { | |||
| return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; | |||
| } | |||
| public static string _forward_name(string name) | |||
| { | |||
| return $"{_FORWARD_PREFIX}{name}_{ops.uid()}"; | |||
| } | |||
| public static string _backward_name(string name) | |||
| { | |||
| return $"{_BACKWARD_PREFIX}{name}_{ops.uid()}"; | |||
| } | |||
| public static (EagerDefinedFunction, ConcreteFunction) _create_forward_backward_with_graph(Dictionary<string, AttrValue> attrs, | |||
| FuncGraph forward_graph, FuncGraph backwards_graph) | |||
| { | |||
| string forward_function_name = _forward_name(forward_graph.Name); | |||
| Dictionary<string, AttrValue> common_attributes; | |||
| if(attrs is null) | |||
| { | |||
| common_attributes = new Dictionary<string, AttrValue>(); | |||
| } | |||
| else | |||
| { | |||
| common_attributes = new Dictionary<string, AttrValue>(attrs); | |||
| } | |||
| if (common_attributes.ContainsKey(IMPLEMENTS_ATTRIBUTE_NAME)) | |||
| { | |||
| common_attributes.Remove(IMPLEMENTS_ATTRIBUTE_NAME); | |||
| } | |||
| var backward_function_attr = _parse_func_attrs(new Dictionary<string, object>() | |||
| { | |||
| {FORWARD_FUNCTION_ATTRIBUTE_NAME, forward_function_name } | |||
| }); | |||
| backward_function_attr.Update(common_attributes); | |||
| var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||
| var forward_function_attr = _parse_func_attrs(new Dictionary<string, object>() | |||
| { | |||
| {BACKWARD_FUNCTION_ATTRIBUTE_NAME, backward_function.Name } | |||
| }); | |||
| forward_function_attr.Update(common_attributes); | |||
| var forward_function = new EagerDefinedFunction(forward_function_name, forward_graph, | |||
| forward_graph.Inputs, forward_graph.Outputs, forward_function_attr); | |||
| return (forward_function, backward_function); | |||
| } | |||
| public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, object> attributes) | |||
| { | |||
| Dictionary<string, AttrValue> attrs = new(); | |||
| foreach(var item in attributes) | |||
| { | |||
| var key = item.Key; | |||
| var value = item.Value; | |||
| if (value is AttrValue attr_value) | |||
| { | |||
| attrs[key] = attr_value; | |||
| } | |||
| else if (value is bool b) | |||
| { | |||
| attrs[key] = new AttrValue() { B = b }; | |||
| } | |||
| else if (value is int i) | |||
| { | |||
| attrs[key] = new AttrValue() { I = i }; | |||
| } | |||
| else if (value is float f) | |||
| { | |||
| attrs[key] = new AttrValue() { F = f }; | |||
| } | |||
| else if(value is string s) | |||
| { | |||
| attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(s) }; | |||
| } | |||
| else if (value is byte[] bytes) | |||
| { | |||
| attrs[key] = new AttrValue() { S = ByteString.CopyFrom(bytes) }; | |||
| } | |||
| else | |||
| { | |||
| throw new ValueError($"Attribute {key} must be bool, int, float, string, or " + | |||
| $"AttrValue. Got {value.GetType()}."); | |||
| } | |||
| } | |||
| return attrs; | |||
| } | |||
| public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, string> attributes) | |||
| { | |||
| Dictionary<string, AttrValue> attrs = new(); | |||
| foreach (var item in attributes) | |||
| { | |||
| var key = item.Key; | |||
| var value = item.Value; | |||
| attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(value) }; | |||
| } | |||
| return attrs; | |||
| } | |||
| } | |||
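An illustrative call to `_parse_func_attrs` (the values and the `_hypothetical_flag` key are made up; the attribute-name constants are the ones defined above): plain CLR values are wrapped into the matching `AttrValue` case:

```csharp
var attrs = monomorphic_function_utils._parse_func_attrs(new Dictionary<string, object>
{
    [monomorphic_function_utils.FORWARD_FUNCTION_ATTRIBUTE_NAME] = "__forward_f_42",  // string → AttrValue.S
    [monomorphic_function_utils.IMPLEMENTS_ATTRIBUTE_NAME] = "addons:MyOp",           // string → AttrValue.S
    ["_hypothetical_flag"] = true,                                                    // bool  → AttrValue.B
});
```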
| public class DelayedRewriteGradientFunctions : TapeGradientFunctions | |||
| { | |||
| EagerDefinedFunction _inference_function; | |||
| Dictionary<string, string> _attrs; | |||
| Dictionary<string, AttrValue> _attrs; | |||
| int _num_inference_outputs; | |||
| public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, string> attrs) | |||
| :base(func_graph, false) | |||
| Dictionary<int, (EagerDefinedFunction, ConcreteFunction)> _cached_function_pairs = new(); | |||
| public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, AttrValue> attrs) | |||
| : base(func_graph, false) | |||
| { | |||
| _func_graph= func_graph; | |||
| _inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name), | |||
| _func_graph = func_graph; | |||
| _inference_function = new EagerDefinedFunction(monomorphic_function_utils._inference_name(_func_graph.Name), | |||
| _func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs); | |||
| _attrs = attrs; | |||
| _num_inference_outputs = _func_graph.Outputs.Length; | |||
| @@ -22,7 +139,7 @@ namespace Tensorflow.Functions | |||
| public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null) | |||
| { | |||
| if(input_tangents is not null) | |||
| if (input_tangents is not null) | |||
| { | |||
| throw new InvalidArgumentError($"unexpectedly got forwardprop information in " + | |||
| $"a class that does not support forwardprop."); | |||
| @@ -32,23 +149,134 @@ namespace Tensorflow.Functions | |||
| public override void Record(Tensors flat_outputs, Tensors inference_args) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| throw new NotImplementedException(); | |||
| base.Record(flat_outputs, inference_args); | |||
| var (backward_function, to_record) = _backward(flat_outputs); | |||
| foreach(var tape in tf.GetTapeSet()) | |||
| { | |||
| tape.RecordOperation(_inference_function.Signature.Name, to_record, | |||
| inference_args.Select(t => new TapeTensor(t)).ToArray(), backward_function); | |||
| } | |||
| } | |||
| //private (BackwardFunction, Tensors) _backward(Tensors outputs) | |||
| //{ | |||
| // Tensor[] backward_function(Tensor[] grads, long[] unneeded_gradients) | |||
| // { | |||
| // var call_op = outputs[0].op; | |||
| public (EagerDefinedFunction, ConcreteFunction) forward_backward(int num_doutputs = -2) | |||
| { | |||
| if(num_doutputs == -2) | |||
| { | |||
| num_doutputs = _num_inference_outputs; | |||
| } | |||
| if(_cached_function_pairs.TryGetValue(num_doutputs, out var target)) | |||
| { | |||
| return target; | |||
| } | |||
| var (forward, backward) = _construct_forward_backward(num_doutputs); | |||
| _cached_function_pairs[num_doutputs] = (forward, backward); | |||
| return (forward, backward); | |||
| // } | |||
| //} | |||
| } | |||
| private string _inference_name(string name) | |||
| private (BackwardFunction, Tensors) _backward(Tensors outputs) | |||
| { | |||
| return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; | |||
| Tensor[] backward_function(Tensor[] args, long[] unneeded_gradients) | |||
| { | |||
| var call_op = outputs[0].op; | |||
| return _rewrite_forward_and_call_backward(call_op, args); | |||
| } | |||
| return (backward_function, outputs); | |||
| } | |||
| internal Tensor[] _rewrite_forward_and_call_backward(Operation op, params object[] doutputs) | |||
| { | |||
| var (forward_function, backward_function) = forward_backward(doutputs.Length); | |||
| if(backward_function.Outputs is null || backward_function.Outputs.Length == 0) | |||
| { | |||
| return backward_function.FlatStructuredOutputs; | |||
| } | |||
| forward_function.AddToGraph(op.graph); | |||
| op._set_func_attr("f", forward_function.Name); | |||
| op._set_type_list_attr("Tout", forward_function.OutputTypes); | |||
| op._add_outputs(forward_function.OutputTypes.Select(x => x.as_tf_dtype()). | |||
| Skip(op.outputs.Length).ToArray(), forward_function.OutputShapes.Skip(op.outputs.Length).ToArray() | |||
| ); | |||
| for(int i = 0; i < op.outputs.Length; i++) | |||
| { | |||
| var func_graph_output = forward_function._func_graph_outputs[i]; | |||
| handle_data_util.copy_handle_data(func_graph_output, op.outputs[i]); | |||
| } | |||
| var capture_mapping = zip(_func_graph.Outputs.Select(t => ops.tensor_id(t)), op.outputs). | |||
| ToDictionary(x => x.Item1, x => x.Item2); | |||
| var remapped_captures = backward_function.CapturedInputs.Select( | |||
| x => capture_mapping.GetOrDefault(ops.tensor_id(x), x) | |||
| ); | |||
| List<Tensor> cleaned_doutputs = new(); | |||
| foreach(var (doutput, placeholder) in zip(doutputs, _func_graph.Outputs)) | |||
| { | |||
| if (backprop_util.IsTrainable(placeholder)) | |||
| { | |||
| if(doutput is IndexedSlices) | |||
| { | |||
| cleaned_doutputs.Add(ops.convert_to_tensor(doutput)); | |||
| } | |||
| else if(doutput is null) | |||
| { | |||
| cleaned_doutputs.Add(default_gradient.zeros_like(placeholder)); | |||
| } | |||
| else if(doutput is Tensor tensor) | |||
| { | |||
| cleaned_doutputs.Add(tensor); | |||
| } | |||
| else | |||
| { | |||
| throw new ValueError($"Unsupported type {doutput.GetType()} in function _rewrite_forward_and_call_backward"); | |||
| } | |||
| } | |||
| } | |||
| return backward_function.CallFlat(cleaned_doutputs.ToArray(), remapped_captures.ToArray()); | |||
| } | |||
| private (EagerDefinedFunction, ConcreteFunction) _construct_forward_backward(int num_doutputs) | |||
| { | |||
| var trainable_outputs = _func_graph.Outputs.Take(num_doutputs).Where(x => backprop_util.IsTrainable(x)); | |||
| List<TensorSpec> signature = new(); | |||
| foreach(var t in trainable_outputs) | |||
| { | |||
| var (shape, dtype) = default_gradient.shape_and_dtype(t); | |||
| signature.Add(new TensorSpec(shape, dtype)); | |||
| } | |||
| Tensor[] _backprop_function(Tensor[] grad_ys) | |||
| { | |||
| return gradients_util._GradientsHelper(trainable_outputs.ToArray(), _func_graph.Inputs, | |||
| grad_ys, src_graph: _func_graph); | |||
| } | |||
| _func_graph.as_default(); | |||
| FuncGraph backwards_graph = new(monomorphic_function_utils._backward_name(_func_graph.Name)); | |||
| FuncGraph.func_graph_from_func(backwards_graph.Name, x => _backprop_function(x.Select(y => | |||
| { | |||
| Debug.Assert(y is Tensor); | |||
| return (Tensor)y; | |||
| }).ToArray()), new object[0], new Dictionary<string, object>(), signature.ToArray(), backwards_graph); | |||
| var backwards_graph_captures = backwards_graph.external_captures; | |||
| var captures_from_forward = backwards_graph_captures.Where(c => c is not EagerTensor && c.graph == _func_graph); | |||
| HashSet<Tensor> existing_outputs = new HashSet<Tensor>(_func_graph.Outputs); | |||
| foreach(var capture in captures_from_forward) | |||
| { | |||
| if (!existing_outputs.Contains(capture)) | |||
| { | |||
| existing_outputs.Add(capture); | |||
| _func_graph.Outputs.Add(capture); | |||
| } | |||
| } | |||
| var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph( | |||
| _attrs, _func_graph, backwards_graph); | |||
| _func_graph.Exit(); | |||
| return (forward_function, backward_function); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,52 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Gradients | |||
| { | |||
| internal static class default_gradient | |||
| { | |||
| public static (Shape, TF_DataType) shape_and_dtype(Tensor t) | |||
| { | |||
| if(t.dtype == dtypes.resource) | |||
| { | |||
| var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); | |||
| if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||
| { | |||
| throw new ValueError($"Internal error: Tried to take gradients (or similar) " + | |||
| $"of a variable without handle data:\n{t}"); | |||
| } | |||
| return (new Shape(handle_data.ShapeAndType[0].Shape), handle_data.ShapeAndType[0].Dtype.as_tf_dtype()); | |||
| } | |||
| return (t.shape, t.dtype); | |||
| } | |||
| public static Tensor zeros_like(Tensor t) | |||
| { | |||
| if(t.dtype == dtypes.resource) | |||
| { | |||
| var (shape, dtype) = shape_and_dtype(t); | |||
| return array_ops.zeros(shape, dtype); | |||
| } | |||
| else | |||
| { | |||
| return array_ops.zeros_like(t); | |||
| } | |||
| } | |||
| public static TF_DataType get_zeros_dtype(Tensor t) | |||
| { | |||
| if(t.dtype == dtypes.resource) | |||
| { | |||
| var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); | |||
| if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||
| { | |||
| throw new ValueError($"Internal error: Tried to take gradients (or similar) " + | |||
| $"of a variable without handle data:\n{t}"); | |||
| } | |||
| return handle_data.ShapeAndType[0].Dtype.as_tf_dtype(); | |||
| } | |||
| return t.dtype; | |||
| } | |||
| } | |||
| } | |||
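Library-internal usage, roughly: for ordinary tensors the helpers are pass-throughs, while `dtypes.resource` tensors are resolved through their handle data so the gradient gets the variable's shape and dtype rather than the handle's:

```csharp
var t = tf.constant(new float[] { 1f, 2f, 3f });
var (shape, dtype) = default_gradient.shape_and_dtype(t);   // ((3), float32)
var zeros = default_gradient.zeros_like(t);                  // float32 zeros of shape (3)
```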
| @@ -14,10 +14,15 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Operations.ControlFlows; | |||
| using static Tensorflow.Binding; | |||
| @@ -148,7 +153,7 @@ namespace Tensorflow | |||
| Tensor[] in_grads = null; | |||
| Func<Operation, Tensor[], Tensor[]> grad_fn = null; | |||
| var is_partitioned_call = _IsPartitionedCall(op); | |||
| var is_func_call = false; | |||
| var is_func_call = src_graph.IsFunction(op.type) || is_partitioned_call; | |||
| var has_out_grads = out_grads.Exists(x => x != null); | |||
| if (has_out_grads && !stop_ops.Contains(op)) | |||
| { | |||
| @@ -162,14 +167,41 @@ namespace Tensorflow | |||
| { | |||
| if (is_func_call) | |||
| { | |||
| EagerDefinedFunction func_call = null; | |||
| if (is_partitioned_call) | |||
| { | |||
| var func_attr = op.get_attr("f"); | |||
| Debug.Assert(func_attr is NameAttrList); | |||
| var func_name = ((NameAttrList)func_attr).Name; | |||
| func_call = src_graph._get_function(func_name); | |||
| if(func_call is null && src_graph.OuterGraph is not null) | |||
| { | |||
| var graph = src_graph.OuterGraph; | |||
| while(graph is not null) | |||
| { | |||
| func_call = graph._get_function(func_name); | |||
| if(func_call is not null) | |||
| { | |||
| break; | |||
| } | |||
| if(graph.OuterGraph is not null) | |||
| { | |||
| graph = graph.OuterGraph; | |||
| } | |||
| else | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| func_call = src_graph._get_function(op.type); | |||
| } | |||
| // skip the following code: | |||
| // `func_call = getattr(op, "__defun", func_call)` | |||
| grad_fn = func_call.csharp_grad_func; | |||
| } | |||
| else | |||
| { | |||
| @@ -213,6 +245,8 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| in_grads = _MaybeCompile(grad_scope, op, out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), | |||
| null, (x, y) => _SymGrad(x, y)); | |||
| throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); | |||
| } | |||
| _VerifyGeneratedGradients(in_grads, op); | |||
| @@ -668,6 +702,36 @@ namespace Tensorflow | |||
| dtypes.resource, dtypes.variant}.Contains(dtype); | |||
| } | |||
| public static int PossibleTapeGradientTypes(Tensor[] tensors) | |||
| { | |||
| var tape_set = tf.GetTapeSet(); | |||
| bool some_tape_watching = false; | |||
| if(tape_set is not null && tape_set.Count > 0) | |||
| { | |||
| foreach(var tape in tape_set) | |||
| { | |||
| if (tape.ShouldRecord(tensors)) | |||
| { | |||
| if(tape.Persistent || some_tape_watching) | |||
| { | |||
| return POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; | |||
| } | |||
| some_tape_watching = true; | |||
| } | |||
| } | |||
| } | |||
| // skip the forward_accumulators. | |||
| if (some_tape_watching) | |||
| { | |||
| return POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; | |||
| } | |||
| else | |||
| { | |||
| return POSSIBLE_GRADIENT_TYPES_NONE; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Return true if op has real gradient. | |||
| /// </summary> | |||
| @@ -688,7 +752,7 @@ namespace Tensorflow | |||
| private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn) | |||
| { | |||
| scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; | |||
| //scope = scope.TrimEnd('/').Replace('/', '_'); | |||
| return grad_fn(op, out_grads); | |||
| } | |||
| @@ -701,5 +765,28 @@ namespace Tensorflow | |||
| throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + | |||
| $"inputs {op.inputs._inputs.Count()}"); | |||
| } | |||
| private static Tensor[] _SymGrad(Operation op, Tensor[] out_grads) | |||
| { | |||
| var f_in = ((Tensor[])op.inputs).Concat(out_grads).ToArray(); | |||
| var f_types = ((Tensor[])op.inputs).Select(x => default_gradient.get_zeros_dtype(x)).ToArray(); | |||
| NameAttrList f = new(); | |||
| if (_IsPartitionedCall(op)) | |||
| { | |||
| var func_attr = op.get_attr("f"); | |||
| Debug.Assert(func_attr is NameAttrList); | |||
| f.Name = ((NameAttrList)func_attr).Name; | |||
| } | |||
| else | |||
| { | |||
| f.Name = op.type; | |||
| } | |||
| foreach(var k in op.node_def.Attr.Keys) | |||
| { | |||
| f.Attr[k] = AttrValue.Parser.ParseFrom(op.node_def.Attr[k].ToByteArray()); | |||
| } | |||
| var in_grads = gen_functional_ops.symbolic_gradient(f_in, f_types, f); | |||
| return in_grads; | |||
| } | |||
| } | |||
| } | |||
| @@ -98,12 +98,23 @@ namespace Tensorflow | |||
| { | |||
| if (op.inputs == null) return null; | |||
| RegisterFromAssembly(); | |||
| var gradient_function = op._gradient_function; | |||
| if(gradient_function is null) | |||
| { | |||
| RegisterFromAssembly(); | |||
| if (!gradientFunctions.ContainsKey(op.type)) | |||
| throw new LookupError($"can't get gradient function through get_gradient_function {op.type}"); | |||
| if (!gradientFunctions.ContainsKey(op.type)) | |||
| throw new LookupError($"can't get gradient function through get_gradient_function {op.type}"); | |||
| return gradientFunctions[op.type]; | |||
| } | |||
| return gradientFunctions[op.type]; | |||
| Tensor[] wrapped_gradient_function(Operation operation, Tensor[] args) | |||
| { | |||
| return gradient_function(operation, args); | |||
| } | |||
| // TODO(Rinne): check if this needs to be registered. | |||
| return wrapped_gradient_function; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,6 +1,15 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Exceptions; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Graphs; | |||
| @@ -11,12 +20,65 @@ namespace Tensorflow.Graphs; | |||
| public class FuncGraph : Graph, IDisposable | |||
| { | |||
| internal SafeFuncGraphHandle _func_graph_handle; | |||
| internal HashSet<Tensor> _resource_tensor_inputs; | |||
| internal HashSet<WeakReference<IVariableV1>> _watched_variables; | |||
| internal IEnumerable<WeakReference<IVariableV1>> _weak_variables; | |||
| internal object[] _structured_outputs; | |||
| internal Dictionary<long, string> _output_names; | |||
| public string FuncName => _graph_key; | |||
| public Tensors Inputs { get; set; } = new Tensors(); | |||
| public Tensors Outputs { get; set; } = new Tensors(); | |||
| public Tensors FlatStructuredOutputs | |||
| { | |||
| get | |||
| { | |||
| List<Tensor> res = new(); | |||
| foreach(var obj in _structured_outputs) | |||
| { | |||
| if(obj is Tensor tensor) | |||
| { | |||
| res.Add(tensor); | |||
| } | |||
| else if(obj is IEnumerable<Tensor> tensors) | |||
| { | |||
| res.AddRange(tensors); | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("The structured outputs member should be tensor or tensors."); | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| } | |||
| public string Name { get; set; } | |||
| public Dictionary<string, string> Attrs { get; set; } | |||
| public IEnumerable<IVariableV1> Variables | |||
| { | |||
| get | |||
| { | |||
| return _weak_variables.Select(v => | |||
| { | |||
| if (v.TryGetTarget(out var target)) | |||
| { | |||
| return target; | |||
| } | |||
| else | |||
| { | |||
| throw new AssertionError("Called a function referencing variables which have been deleted. " + | |||
| "This likely means that function-local variables were created and " + | |||
| "not referenced elsewhere in the program. This is generally a " + | |||
| "mistake; consider storing variables in an object attribute on first call."); | |||
| } | |||
| }); | |||
| } | |||
| internal set | |||
| { | |||
| _weak_variables = value.Select(x => new WeakReference<IVariableV1>(x)); | |||
| } | |||
| } | |||
| public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | |||
| public Dictionary<string, AttrValue> Attrs { get; set; } | |||
| Dictionary<long, (Tensor, Tensor)> _captures | |||
| = new Dictionary<long, (Tensor, Tensor)>(); | |||
| @@ -42,9 +104,12 @@ public class FuncGraph : Graph, IDisposable | |||
| outer_graph = outer_graph.OuterGraph; | |||
| _graph_key = Name = name; | |||
| building_function = true; | |||
| _weak_variables = new List<WeakReference<IVariableV1>>(); | |||
| _resource_tensor_inputs = new HashSet<Tensor>(); | |||
| _watched_variables = new HashSet<WeakReference<IVariableV1>>(); | |||
| } | |||
| public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, string> attrs) : base() | |||
| public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, AttrValue> attrs) : base() | |||
| { | |||
| outer_graph = ops.get_default_graph(); | |||
| while (outer_graph.building_function) | |||
| @@ -55,6 +120,9 @@ public class FuncGraph : Graph, IDisposable | |||
| // Will test whether FuncGraph has a memory leak | |||
| // c_api.TF_DeleteGraph(_handle); | |||
| _handle = handle; | |||
| _weak_variables = new List<WeakReference<IVariableV1>>(); | |||
| _resource_tensor_inputs = new HashSet<Tensor>(); | |||
| _watched_variables = new HashSet<WeakReference<IVariableV1>>(); | |||
| } | |||
| public void replace_capture(Tensor tensor, Tensor placeholder) | |||
| @@ -62,14 +130,14 @@ public class FuncGraph : Graph, IDisposable | |||
| _captures[tensor.Id] = (tensor, placeholder); | |||
| } | |||
| public void ToGraph(Operation[] opers, | |||
| public unsafe void ToGraph(Operation[] opers, | |||
| Tensor[] inputs, Tensor[] outputs, | |||
| string[] output_names) | |||
| { | |||
| var status = new Status(); | |||
| if (output_names != null && output_names.Length == 0) | |||
| if (output_names is null) | |||
| { | |||
| output_names = null; | |||
| output_names = new string[0]; | |||
| }; | |||
| _func_graph_handle = c_api.TF_GraphToFunction(_handle, | |||
| @@ -81,7 +149,7 @@ public class FuncGraph : Graph, IDisposable | |||
| inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
| outputs.Length, | |||
| outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
| output_names, | |||
| output_names.Length != outputs.Length ? null : output_names, | |||
| IntPtr.Zero, | |||
| null, | |||
| status); | |||
| @@ -211,6 +279,19 @@ public class FuncGraph : Graph, IDisposable | |||
| Inputs.Add(placeholder); | |||
| } | |||
| Tensor pop_capture(Tensor tensor) | |||
| { | |||
| if(_captures.TryGetValue(tensor.Id, out var capture)) | |||
| { | |||
| _captures.Remove(tensor.Id); | |||
| return capture.Item2; | |||
| } | |||
| else | |||
| { | |||
| return null; | |||
| } | |||
| } | |||
| Tensor _create_substitute_placeholder(Tensor value, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| @@ -234,10 +315,7 @@ public class FuncGraph : Graph, IDisposable | |||
| foreach (var (_name, attr_value) in enumerate(Attrs)) | |||
| { | |||
| var serialized = new AttrValue | |||
| { | |||
| S = ByteString.CopyFromUtf8(attr_value) | |||
| }.ToByteArray(); | |||
| var serialized = attr_value.ToByteArray(); | |||
| c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status); | |||
| tf.Status.Check(true); | |||
| } | |||
| @@ -260,4 +338,261 @@ public class FuncGraph : Graph, IDisposable | |||
| { | |||
| c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status); | |||
| } | |||
| public static FuncGraph func_graph_from_func(string name, Func<object[], object[]> func, | |||
| object[] args, Dictionary<string, object> kwargs, TensorSpec[] signature = null, | |||
| FuncGraph func_graph = null, bool autograph = false, object autograph_options = null, | |||
| bool add_control_dependencies = true, string[] arg_names = null, | |||
| Tensor op_return_value = null, bool capture_by_value = false, | |||
| bool acd_record_initial_resource_uses = false) | |||
| { | |||
| if(func_graph is null) | |||
| { | |||
| func_graph = new FuncGraph(name); | |||
| } | |||
| // TODO(Rinne): deal with control dependencies. | |||
| func_graph.as_default(); | |||
| var current_scope = variable_scope.get_variable_scope(); | |||
| var default_use_resource = current_scope.use_resource; | |||
| current_scope.use_resource = true; | |||
| if(signature is not null) | |||
| { | |||
| args = signature; | |||
| kwargs = new Dictionary<string, object>(); | |||
| } | |||
| var func_args = _get_defun_inputs_from_args(args, arg_names); | |||
| var func_kwargs = _get_defun_inputs_from_kwargs(kwargs); | |||
| if(func_kwargs is not null && func_kwargs.Count > 0) | |||
| { | |||
| throw new NotImplementedException("The keyword args has not been supported in `func_graph_from_func`."); | |||
| } | |||
| foreach(var arg in nest.flatten<object>(new object[] { func_args, func_kwargs })) | |||
| { | |||
| if(arg is Tensor tensor && tensor.dtype == dtypes.resource) | |||
| { | |||
| func_graph._resource_tensor_inputs.Add(tensor); | |||
| } | |||
| else if (arg is ResourceVariable variable) | |||
| { | |||
| func_graph._resource_tensor_inputs.Add(variable.Handle); | |||
| } | |||
| } | |||
| // skip the assignment of `func_graph.structured_input_signature`. | |||
| var flat_func_args = nest.flatten(func_args as object); | |||
| var flat_func_kwargs = nest.flatten(func_kwargs as object); | |||
| func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | |||
| .Where(x => x is Tensor).Select(x => (Tensor)x)); | |||
| //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | |||
| //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | |||
| Tensor convert(object x) | |||
| { | |||
| if (x is null) return null; | |||
| Tensor res = null; | |||
| if(op_return_value is not null && x is Operation) | |||
| { | |||
| tf_with(ops.control_dependencies(new object[] { x }), _ => | |||
| { | |||
| res = array_ops.identity(op_return_value); | |||
| }); | |||
| } | |||
| else if(x is not TensorArray) | |||
| { | |||
| Debug.Assert(x is Tensor); | |||
| res = ops.convert_to_tensor_or_composite(x as Tensor); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException($"The `TensorArray` is not supported here currently."); | |||
| } | |||
| if (add_control_dependencies) | |||
| { | |||
| // TODO(Rinne): `x = deps_ctx.mark_as_return(x)`. | |||
| } | |||
| return res; | |||
| } | |||
| if (autograph) | |||
| { | |||
| throw new NotImplementedException("The autograph of `func_graph_from_func` has not been supported."); | |||
| } | |||
| var func_outputs = func(func_args); | |||
| func_outputs = variable_utils.convert_variables_to_tensors(func_outputs); | |||
| func_outputs = func_outputs.Select(x => convert(x)).ToArray(); | |||
| // TODO(Rinne): `check_func_mutation`. | |||
| current_scope.use_resource = default_use_resource; | |||
| var graph_variables = func_graph._watched_variables.ToList(); | |||
| HashSet<IVariableV1> arg_variables = new HashSet<IVariableV1>(); | |||
| List<Tensor> inputs = new(); | |||
| foreach(var arg in composite_tensor_utils.flatten_with_variables(func_args)) | |||
| { | |||
| if(arg is BaseResourceVariable variable) | |||
| { | |||
| var resource_placeholder = func_graph.pop_capture(variable.Handle); | |||
| if(resource_placeholder is null) | |||
| { | |||
| continue; | |||
| } | |||
| Debug.Assert(variable is IVariableV1); | |||
| arg_variables.Add(variable as IVariableV1); | |||
| inputs.Add(resource_placeholder); | |||
| } | |||
| else if(arg is Tensor tensor) | |||
| { | |||
| inputs.Add(tensor); | |||
| } | |||
| } | |||
| var variables = graph_variables.Select(v => | |||
| { | |||
| if (v.TryGetTarget(out var target)) | |||
| { | |||
| return target; | |||
| } | |||
| else | |||
| { | |||
| return null; | |||
| } | |||
| }).Where(v => v is not null && !arg_variables.Contains(v)); | |||
| func_graph.Inputs = inputs.Concat(func_graph.internal_captures).ToArray(); | |||
| func_graph._structured_outputs = func_outputs; | |||
| func_graph.Outputs.AddRange(func_graph.FlatStructuredOutputs.Where(x => x is not null) | |||
| .Select(x => func_graph.capture(x))); | |||
| func_graph.Variables = variables; | |||
| func_graph.Exit(); | |||
| if (add_control_dependencies) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| } | |||
| return func_graph; | |||
| } | |||
| private static object[] _get_defun_inputs_from_args(object[] args, string[] names) | |||
| { | |||
| return _get_defun_inputs(args, names, args) as object[]; | |||
| } | |||
| private static Dictionary<string, object> _get_defun_inputs_from_kwargs(Dictionary<string, object> kwargs) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| Debug.Assert(kwargs is null || kwargs.Count == 0); | |||
| return kwargs; | |||
| //string[] names; | |||
| //object[] args; | |||
| //if(kwargs is not null && kwargs.Count > 0) | |||
| //{ | |||
| // var sorted_kwargs = kwargs.OrderBy(x => x.Key); | |||
| // names = sorted_kwargs.Select(x => x.Key).ToArray(); | |||
| // args = sorted_kwargs.Select(x => x.Value).ToArray(); | |||
| //} | |||
| //else | |||
| //{ | |||
| // names = new string[0]; | |||
| // args = new object[0]; | |||
| //} | |||
| //return _get_defun_inputs(args, names, kwargs) as Dictionary<string, object>; | |||
| } | |||
| private static object _get_defun_inputs(object[] args, string[] names, object structured_args) | |||
| { | |||
| List<object> function_inputs = new(); | |||
| if(names is null) | |||
| { | |||
| names = new string[args.Length]; | |||
| } | |||
| foreach(var (arg_value, name) in zip(args, names)) | |||
| { | |||
| foreach(var val in composite_tensor_utils.flatten_with_variables_or_variable_specs(arg_value)) | |||
| { | |||
| function_inputs.Add(_get_defun_input(val, name)); | |||
| } | |||
| } | |||
| return nest.pack_sequence_as(structured_args, nest.flatten<object>(function_inputs), true); | |||
| } | |||
| private static object _get_defun_input(object arg, string name) | |||
| { | |||
| var func_graph = ops.get_default_graph() as FuncGraph; | |||
| Debug.Assert(func_graph is not null); | |||
| if (arg is Tensor tensor) | |||
| { | |||
| Tensor placeholder; | |||
| try | |||
| { | |||
| placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); | |||
| } | |||
| catch (ValueError) | |||
| { | |||
| // TODO(Rinne): Add warning here. | |||
| placeholder = tf.placeholder(tensor.dtype, tensor.shape); | |||
| } | |||
| handle_data_util.copy_handle_data(tensor, placeholder); | |||
| if (name is not null) | |||
| { | |||
| placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||
| { | |||
| S = tf.compat.as_bytes(name) | |||
| }); | |||
| } | |||
| return placeholder; | |||
| } | |||
| else if (arg is TensorSpec spec) | |||
| { | |||
| string requested_name; | |||
| if (!string.IsNullOrEmpty(spec.name)) | |||
| { | |||
| requested_name = spec.name; | |||
| } | |||
| else | |||
| { | |||
| requested_name = name; | |||
| } | |||
| Tensor placeholder; | |||
| try | |||
| { | |||
| placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); | |||
| } | |||
| catch (ValueError) | |||
| { | |||
| // TODO(Rinne): Add warning here. | |||
| placeholder = tf.placeholder(spec.dtype, spec.shape); | |||
| } | |||
| if (name is not null) | |||
| { | |||
| placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||
| { | |||
| S = tf.compat.as_bytes(requested_name) | |||
| }); | |||
| } | |||
| return placeholder; | |||
| } | |||
| else if (arg is BaseResourceVariable variable) | |||
| { | |||
| var placeholder = func_graph.capture(variable.Handle, name); | |||
| placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||
| { | |||
| S = tf.compat.as_bytes(name) | |||
| }); | |||
| return arg; | |||
| } | |||
| // TODO(Rinne): deal with `VariableSpec`. | |||
| else | |||
| { | |||
| return arg; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,6 @@ | |||
| namespace Tensorflow | |||
| using Tensorflow.Graphs; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Graph | |||
| { | |||
| @@ -6,5 +8,10 @@ | |||
| { | |||
| } | |||
| internal GraphOverrideGradientContext _override_gradient_function(Dictionary<string, Func<Operation, object[], Tensor[]>> gradient_function_map) | |||
| { | |||
| return new GraphOverrideGradientContext(this, gradient_function_map); | |||
| } | |||
| } | |||
| } | |||
| @@ -118,7 +118,7 @@ namespace Tensorflow | |||
| /// <param name="compute_device">(Optional.) If True, device functions will be executed | |||
| /// to compute the device property of the Operation.</param> | |||
| /// <returns>An `Operation` object.</returns> | |||
| public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) | |||
| public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true, OperationDescription desc = null) | |||
| { | |||
| var ret = new Operation(c_op, this); | |||
| _add_op(ret); | |||
| @@ -21,6 +21,7 @@ using System.Collections.Specialized; | |||
| using System.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Common.Extensions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -88,6 +89,7 @@ namespace Tensorflow | |||
| private List<Operation> _unfetchable_ops = new List<Operation>(); | |||
| private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | |||
| private Dictionary<string, EagerDefinedFunction> _functions = new(); | |||
| internal Dictionary<string, Func<Operation, object[], Tensor[]>> _gradient_function_map = new(); | |||
| private VersionDef _graph_def_versions = new VersionDef() | |||
| { | |||
| Producer = versions.GRAPH_DEF_VERSION, | |||
| @@ -161,13 +163,30 @@ namespace Tensorflow | |||
| return _functions.ContainsKey(tf.compat.as_str(name)); | |||
| } | |||
| public void AddFunction(EagerDefinedFunction function) | |||
| internal void AddFunction(EagerDefinedFunction function) | |||
| { | |||
| _check_not_finalized(); | |||
| var name = function.Name; | |||
| if(function._grad_func_name is not null && function.csharp_grad_func is not null) | |||
| { | |||
| throw new ValueError($"Gradient defined twice for function {name}"); | |||
| } | |||
| // TODO(Rinne): deal with c_graph | |||
| var c_graph = this.c_graph; | |||
| var func = function._c_func.Get(); | |||
| Status status = new(); | |||
| if (function._grad_func is not null) | |||
| { | |||
| var gradient = function._grad_func._c_func.Get(); | |||
| c_api.TF_GraphCopyFunction(c_graph, func, gradient, status); | |||
| status.Check(true); | |||
| } | |||
| else | |||
| { | |||
| c_api.TF_GraphCopyFunction(c_graph, func, new SafeFuncGraphHandle(IntPtr.Zero), status); | |||
| status.Check(true); | |||
| } | |||
| _functions[tf.compat.as_str(name)] = function; | |||
| @@ -332,6 +351,9 @@ namespace Tensorflow | |||
| private void _create_op_helper(Operation op, bool compute_device = true) | |||
| { | |||
| // high priority | |||
| // TODO(Rinne): complete the implementation. | |||
| op._gradient_function = _gradient_function_map.GetOrDefault(op.type, null); | |||
| _record_op_seen_by_control_dependencies(op); | |||
| } | |||
| @@ -548,6 +570,11 @@ namespace Tensorflow | |||
| ops.pop_graph(); | |||
| } | |||
| internal EagerDefinedFunction _get_function(string name) | |||
| { | |||
| return _functions.GetOrDefault(name, null); | |||
| } | |||
| string debugString = string.Empty; | |||
| public override string ToString() | |||
| { | |||
| @@ -0,0 +1,37 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| namespace Tensorflow.Graphs | |||
| { | |||
| internal class GraphOverrideGradientContext: ITensorFlowObject | |||
| { | |||
| Graph _graph; | |||
| Dictionary<string, Func<Operation, object[], Tensor[]>> _new_gradient_function_map; | |||
| public GraphOverrideGradientContext(Graph graph, | |||
| Dictionary<string, Func<Operation, object[], Tensor[]>> new_gradient_function_map) | |||
| { | |||
| _graph = graph; | |||
| _new_gradient_function_map = new_gradient_function_map; | |||
| } | |||
| [DebuggerStepThrough] | |||
| public void __enter__() | |||
| { | |||
| Debug.Assert(_graph._gradient_function_map.Count == 0); | |||
| _graph._gradient_function_map = _new_gradient_function_map; | |||
| } | |||
| [DebuggerStepThrough] | |||
| public void __exit__() | |||
| { | |||
| _graph._gradient_function_map = new Dictionary<string, Func<Operation, object[], Tensor[]>>(); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -20,6 +20,9 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| using Google.Protobuf; | |||
| using Google.Protobuf.WellKnownTypes; | |||
| using System.Diagnostics; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -47,6 +50,8 @@ namespace Tensorflow | |||
| private readonly Graph _graph; | |||
| internal Func<Operation, object[], Tensor[]> _gradient_function; | |||
| public string type => OpType; | |||
| public Graph graph => _graph; | |||
| @@ -61,7 +66,7 @@ namespace Tensorflow | |||
| public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||
| // OperationDescription _opDesc; | |||
| //private OperationDescription _op_desc; | |||
| public NodeDef node_def => GetNodeDef(); | |||
| @@ -216,21 +221,19 @@ namespace Tensorflow | |||
| var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
| string oneof_value = x.ValueCase.ToString(); | |||
| if (string.IsNullOrEmpty(oneof_value)) | |||
| return null; | |||
| var oneof_value = x.ValueCase; | |||
| if (oneof_value == AttrValue.ValueOneofCase.None) | |||
| return new object[0]; | |||
| switch (oneof_value.ToLower()) | |||
| if(oneof_value == AttrValue.ValueOneofCase.List) | |||
| { | |||
| case "list": | |||
| throw new NotImplementedException($"Unsupported field type in {oneof_value}"); | |||
| case "type": | |||
| return x.Type; | |||
| case "s": | |||
| return x.S.ToStringUtf8(); | |||
| default: | |||
| return x.GetType().GetProperty(oneof_value).GetValue(x); | |||
| throw new NotImplementedException($"Unsupported field type in {oneof_value}"); | |||
| } | |||
| if(oneof_value == AttrValue.ValueOneofCase.Type) | |||
| { | |||
| return dtypes.as_tf_dtype(x.Type); | |||
| } | |||
| return ProtoUtils.GetSingleAttrValue(x, oneof_value); | |||
| } | |||
| public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | |||
| @@ -309,5 +312,83 @@ namespace Tensorflow | |||
| } | |||
| public NDArray numpy() => throw new NotImplementedException(""); | |||
| internal void _add_outputs(TF_DataType[] types, Shape[] shapes) | |||
| { | |||
| Debug.Assert(types.Length == shapes.Length); | |||
| int orig_num_outputs = this.outputs.Length; | |||
| //var new_outputs = new List<Tensor>(_outputs); | |||
| var old_outputs = _outputs; | |||
| _outputs = new Tensor[orig_num_outputs + types.Length]; | |||
| for(int i = 0; i < orig_num_outputs; i++) | |||
| { | |||
| _outputs[i] = old_outputs[i]; | |||
| } | |||
| // Since `_outputs` is declared as an array, appending outputs requires allocating | |||
| // a new array, which has some performance cost. | |||
| // The type of `outputs` may need to be reconsidered in the future. | |||
| for(int i = 0; i < types.Length; i++) | |||
| { | |||
| var t = new Tensor(this, orig_num_outputs + i, types[i]); | |||
| _outputs[orig_num_outputs + i] = t; | |||
| //t = tf.ensure_shape(t, shapes[i]); | |||
| t.shape = shapes[i]; | |||
| //new_outputs.Add(t); | |||
| } | |||
| //_outputs = new_outputs.ToArray(); | |||
| } | |||
| internal void _set_func_attr(string attr_name, string func_name) | |||
| { | |||
| var func = new NameAttrList() { Name = func_name }; | |||
| _set_attr(attr_name, new AttrValue() { Func = func }); | |||
| } | |||
| internal void _set_type_list_attr(string attr_name, DataType[] types) | |||
| { | |||
| if(types is null || types.Length == 0) | |||
| { | |||
| return; | |||
| } | |||
| var type_list = new AttrValue.Types.ListValue(); | |||
| type_list.Type.AddRange(types); | |||
| _set_attr(attr_name, new AttrValue() { List = type_list }); | |||
| } | |||
| internal void _set_attr(string attr_name, AttrValue attr_value) | |||
| { | |||
| var buffer = new Buffer(attr_value.ToByteArray()); | |||
| try | |||
| { | |||
| _set_attr_with_buf(attr_name, buffer); | |||
| } | |||
| finally | |||
| { | |||
| buffer.Release(); | |||
| } | |||
| } | |||
| internal void _set_attr_with_buf(string attr_name, Buffer attr_buf) | |||
| { | |||
| //if(_op_desc is null) | |||
| //{ | |||
| // //var new_node_def = NodeDef.Parser.ParseFrom(node_def.ToByteArray()); | |||
| // //new_node_def.Name += "_temp"; | |||
| // //var op = new Operation(new_node_def, graph, inputs, _output_types, control_inputs, _input_types); | |||
| // //Status status = new(); | |||
| // //c_api.TF_SetAttrBool(op._op_desc, "trainable", true); | |||
| // ////c_api.TF_SetAttrValueProto(op._op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status); | |||
| // //status.Check(true); | |||
| // // TODO(Rinne): deal with it. Give a warning or make the Operation always contains `op_desc`. | |||
| //} | |||
| //else | |||
| //{ | |||
| // //Status status = new(); | |||
| // //c_api.TF_SetAttrValueProto(_op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status); | |||
| // //status.Check(true); | |||
| //} | |||
| } | |||
| } | |||
| } | |||
| @@ -208,9 +208,9 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data); | |||
| //[DllImport(TensorFlowLibName)] | |||
| //public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
| //[DllImport(TensorFlowLibName)] | |||
| //public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data); | |||
| } | |||
| } | |||
| @@ -39,7 +39,7 @@ namespace Tensorflow | |||
| if (config is null) | |||
| { | |||
| config = function_utils.get_disabled_rewriter_config(); | |||
| config = function_utils.get_disabled_rewriter_config().ToString(); | |||
| } | |||
| if (executor_type is null) | |||
| @@ -79,5 +79,50 @@ namespace Tensorflow.Operations | |||
| }; | |||
| } | |||
| public static Tensor[] symbolic_gradient(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name = null) | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| try | |||
| { | |||
| var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||
| "SymbolicGradient", name, input, Tout, f)); | |||
| return _result; | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| try | |||
| { | |||
| return symbolic_gradient_eager_fallback(input, Tout, f, name, ctx); | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| } | |||
| var op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, new object[] { input, Tout, f }); | |||
| var result = op.outputs; | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return result; | |||
| } | |||
| public static Tensor[] symbolic_gradient_eager_fallback(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name, Context ctx) | |||
| { | |||
| object[] attrs = new object[] { "Tin", input, "Tout", Tout, "f", f }; | |||
| var result = execute.executes("SymbolicGradient", Tout.Length, input, attrs, ctx, name); | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return result; | |||
| } | |||
| } | |||
| } | |||
| @@ -10050,13 +10050,51 @@ namespace Tensorflow.Operations | |||
| /// </remarks> | |||
| public static Tensor ensure_shape(Tensor input, Shape shape, string name = "EnsureShape") | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| try | |||
| { | |||
| var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("EnsureShape", name, input, shape)); | |||
| return _result[0]; | |||
| } | |||
| catch (Exception) | |||
| { | |||
| Console.WriteLine(); | |||
| } | |||
| try | |||
| { | |||
| return ensure_shape_eager_fallback(input, shape, name, ctx); | |||
| } | |||
| catch (Exception) | |||
| { | |||
| Console.WriteLine(); | |||
| } | |||
| } | |||
| var dict = new Dictionary<string, object>(); | |||
| dict["input"] = input; | |||
| dict["shape"] = shape; | |||
| var op = tf.OpDefLib._apply_op_helper("EnsureShape", name: name, keywords: dict); | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return op.output; | |||
| } | |||
| public static Tensor ensure_shape_eager_fallback(Tensor input, Shape shape, string name, Context ctx) | |||
| { | |||
| object[] attrs = new object[4] { "shape", shape, "T", input.dtype.as_datatype_enum() }; | |||
| var _result = execute.executes("EnsureShape", 1, new Tensor[] { input }, | |||
| attrs, ctx, name); | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return _result[0]; | |||
| } | |||
| /// <summary> | |||
| /// Creates or finds a child frame, and makes <c>data</c> available to the child frame. | |||
| /// </summary> | |||
| @@ -52,5 +52,7 @@ namespace Tensorflow.Operations | |||
| // TODO(Rinne): enable it. (currently the internal c api cannot be invoked.) | |||
| //c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); | |||
| } | |||
| public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op); | |||
| } | |||
| } | |||
| @@ -24,6 +24,7 @@ using static Tensorflow.CppShapeInferenceResult.Types; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Operations; | |||
| using System.Buffers; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -41,12 +42,7 @@ namespace Tensorflow | |||
| name: name); | |||
| } | |||
| public static bool is_resource_variable(IVariableV1 var) | |||
| { | |||
| return var is BaseResourceVariable; | |||
| } | |||
| public static bool is_resource_variable(Trackable var) | |||
| public static bool is_resource_variable(object var) | |||
| { | |||
| return var is BaseResourceVariable; | |||
| } | |||
| @@ -138,10 +134,27 @@ namespace Tensorflow | |||
| /// <param name="graph_mode"></param> | |||
| internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
| { | |||
| tensor.HandleData = handle_data; | |||
| if (!graph_mode) | |||
| return; | |||
| var size = handle_data.ShapeAndType.Count; | |||
| var shapes = new IntPtr[size]; | |||
| var types = new DataType[size]; | |||
| var ranks = new int[size]; | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| var shapeAndType = handle_data.ShapeAndType[i]; | |||
| types[i] = shapeAndType.Dtype; | |||
| ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; | |||
| var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); | |||
| } | |||
| //tensor.HandleData = handle_data; | |||
| //if (!graph_mode) | |||
| // return; | |||
| //var shapes = handle_data.ShapeAndType.Select(x => x.Shape); | |||
| //var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray(); | |||
| //var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray(); | |||
| @@ -196,24 +209,6 @@ namespace Tensorflow | |||
| throw new NotImplementedException(""); | |||
| } | |||
| private static HandleData get_eager_safe_handle_data(Tensor handle) | |||
| { | |||
| if (handle.Handle == null) | |||
| { | |||
| var data = new HandleData(); | |||
| data.ShapeAndType.Add(new HandleShapeAndType | |||
| { | |||
| Shape = handle.shape.as_shape_proto(), | |||
| Dtype = handle.dtype.as_datatype_enum() | |||
| }); | |||
| return data; | |||
| } | |||
| else | |||
| { | |||
| return HandleData.Parser.ParseFrom(handle.BufferToArray()); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Copies an existing variable to a new graph, with no initializer. | |||
| /// </summary> | |||
| @@ -281,5 +276,31 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| } | |||
| public static HandleData get_eager_safe_handle_data(Tensor handle) | |||
| { | |||
| if (handle.Handle == null) | |||
| { | |||
| var data = new HandleData(); | |||
| data.ShapeAndType.Add(new HandleShapeAndType | |||
| { | |||
| Shape = handle.shape.as_shape_proto(), | |||
| Dtype = handle.dtype.as_datatype_enum() | |||
| }); | |||
| return data; | |||
| } | |||
| else | |||
| { | |||
| return HandleData.Parser.ParseFrom(handle.BufferToArray()); | |||
| } | |||
| //if(handle is EagerTensor) | |||
| //{ | |||
| // return handle.HandleData; | |||
| //} | |||
| //else | |||
| //{ | |||
| // return handle_data_util.get_resource_handle_data(handle); | |||
| //} | |||
| } | |||
| } | |||
| } | |||
| @@ -101,6 +101,7 @@ namespace Tensorflow | |||
| _op = op; | |||
| _value_index = value_index; | |||
| _override_dtype = dtype; | |||
| _tf_output = null; | |||
| _id = ops.uid(); | |||
| } | |||
| @@ -136,9 +136,9 @@ namespace Tensorflow | |||
| protected virtual void SetShapeInternal(Shape value) | |||
| { | |||
| if (value == null) | |||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status); | |||
| c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), null, -1, tf.Status); | |||
| else | |||
| c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status); | |||
| c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), value.dims, value.ndim, tf.Status); | |||
| } | |||
| public int[] _shape_tuple() | |||
| @@ -177,7 +177,9 @@ namespace Tensorflow | |||
| if (_handle == null) | |||
| { | |||
| var output = _as_tf_output(); | |||
| int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status); | |||
| Status status = new(); | |||
| int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); | |||
| status.Check(true); | |||
| return ndim; | |||
| } | |||
| @@ -199,7 +201,7 @@ namespace Tensorflow | |||
| public TF_Output _as_tf_output() | |||
| { | |||
| if (!_tf_output.HasValue) | |||
| _tf_output = new TF_Output(op, value_index); | |||
| _tf_output = new TF_Output(op, _value_index); | |||
| return _tf_output.Value; | |||
| } | |||
| @@ -56,7 +56,7 @@ namespace Tensorflow | |||
| public void Add(Tensor tensor) | |||
| => items.Add(tensor); | |||
| public void AddRange(Tensor[] tensors) | |||
| public void AddRange(IEnumerable<Tensor> tensors) | |||
| => items.AddRange(tensors); | |||
| public void Insert(int index, Tensor tensor) | |||
| @@ -12,11 +12,12 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph) | |||
| { | |||
| this.forward_backward = concrete_function.forward_backward; | |||
| this.Outputs = concrete_function.Outputs; | |||
| this.ReturnType = concrete_function.ReturnType; | |||
| this.OutputStructure = concrete_function.OutputStructure; | |||
| this.ArgKeywords = concrete_function.ArgKeywords; | |||
| throw new NotImplementedException(); | |||
| //this.forward_backward = concrete_function.forward_backward; | |||
| //this.Outputs = concrete_function.Outputs; | |||
| //this.ReturnType = concrete_function.ReturnType; | |||
| //this.OutputStructure = concrete_function.OutputStructure; | |||
| //this.ArgKeywords = concrete_function.ArgKeywords; | |||
| } | |||
| } | |||
| } | |||
| @@ -30,6 +30,31 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| var function_spec = _deserialize_function_spec_as_nonmethod(saved_function.FunctionSpec); | |||
| Tensor[] restored_function_body(Tensor[] inputs) | |||
| { | |||
| if(saved_function.ConcreteFunctions is null || saved_function.ConcreteFunctions.Count == 0) | |||
| { | |||
| throw new ValueError("Found zero restored functions for caller function."); | |||
| } | |||
| foreach(var function_name in saved_function.ConcreteFunctions) | |||
| { | |||
| var function = concrete_functions[function_name]; | |||
| if(function.CapturedInputs.Any(x => x is null)) | |||
| { | |||
| throw new ValueError("Looks like you are trying to run a loaded " + | |||
| "non-Keras model that was trained using tf.distribute.experimental.ParameterServerStrategy " + | |||
| "with variable partitioning, which is not currently supported. Try using Keras to define your model " + | |||
| "if possible."); | |||
| } | |||
| if(_concrete_function_callable_with(function, inputs, false)) | |||
| { | |||
| return _call_concrete_function(function, inputs); | |||
| } | |||
| } | |||
| throw new ValueError("Unexpected runtime behavior, please submit an issue to " + | |||
| "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| List<ConcreteFunction> concrete_function_objects = new(); | |||
| foreach(var concrete_function_name in saved_function.ConcreteFunctions) | |||
| { | |||
| @@ -40,17 +65,10 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| cf._set_function_spec(function_spec); | |||
| } | |||
| foreach(var function_name in saved_function.ConcreteFunctions) | |||
| { | |||
| var function = concrete_functions[function_name]; | |||
| if(_concrete_function_callable_with(function, null, false)) | |||
| { | |||
| return new RestoredFunction(null, function, "function_from_deserialization"); | |||
| } | |||
| } | |||
| return new RestoredFunction(x => x, new ConcreteFunction(x => x, TF_DataType.TF_FLOAT), "function_return_itself"); | |||
| //throw new ValueError("Unexpected runtime behavior, please submit an issue to " + | |||
| // "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| var restored_function = new RestoredFunction(restored_function_body, nameof(restored_function_body), | |||
| function_spec, concrete_function_objects); | |||
| return restored_function; | |||
| } | |||
| public static Dictionary<string, ConcreteFunction> load_function_def_library(FunctionDefLibrary library, | |||
| @@ -102,15 +120,17 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); | |||
| if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) | |||
| object structured_input_signature = null; | |||
| object structured_outputs = null; | |||
| if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| //var proto = saved_object_graph.ConcreteFunctions[orig_name]; | |||
| //throw new NotImplementedException(); | |||
| var proto = saved_object_graph.ConcreteFunctions[orig_name]; | |||
| structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature); | |||
| structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature); | |||
| } | |||
| graph.as_default(); | |||
| var func_graph = function_def_lib.function_def_to_graph(fdef, null, null); | |||
| var func_graph = function_def_lib.function_def_to_graph(fdef, structured_input_signature, structured_outputs); | |||
| graph.Exit(); | |||
| _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients); | |||
| @@ -124,7 +144,7 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| fdef.Attr.Remove("_input_shapes"); | |||
| } | |||
| var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value.S.ToString())); | |||
| var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value)); | |||
| if(wrapper_function is not null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -142,8 +162,7 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| var gradient_op_type = gradients_to_register[orig_name]; | |||
| loaded_gradients[gradient_op_type] = func; | |||
| // TODO(Rinne): deal with gradient registry. | |||
| //new RegisteredGradient() { RegisteredOpType = gradient_op_type }. | |||
| ops.RegisterGradientFunction(gradient_op_type, _gen_gradient_func(func)); | |||
| } | |||
| } | |||
| return functions; | |||
| @@ -203,6 +222,16 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| } | |||
| } | |||
| private static Func<Operation, Tensor[], Tensor[]> _gen_gradient_func(ConcreteFunction func) | |||
| { | |||
| return (unused_op, result_grads) => | |||
| { | |||
| result_grads = zip(result_grads, func.func_graph.Inputs) | |||
| .Select((item) => item.Item1 is null ? default_gradient.zeros_like(item.Item2) : item.Item1).ToArray(); | |||
| return func.CallFlat(result_grads, func.CapturedInputs); | |||
| }; | |||
| } | |||
| private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients) | |||
| { | |||
| foreach(var op in func_graph.get_operations()) | |||
| @@ -210,14 +239,14 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | |||
| { | |||
| var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; | |||
| // TODO(Rinne): deal with `op._gradient_function`. | |||
| op.op._gradient_function = function._get_gradient_function(); | |||
| } | |||
| string gradient_op_type = null; | |||
| try | |||
| { | |||
| gradient_op_type = op.op.get_attr("_gradient_op_type") as string; | |||
| } | |||
| catch(Exception e) | |||
| catch(InvalidArgumentError) | |||
| { | |||
| continue; | |||
| } | |||
| @@ -389,7 +418,7 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList(); | |||
| concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; | |||
| var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); | |||
| //var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); | |||
| // TODO(Rinne): set the function spec. | |||
| concrete_function.AddTograph(); | |||
| return concrete_function; | |||
| @@ -413,19 +442,31 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| return function.CallFlat(inputs, function.CapturedInputs); | |||
| } | |||
| private static bool _concrete_function_callable_with(ConcreteFunction function, Tensors inputs, bool allow_conversion) | |||
| private static bool _concrete_function_callable_with(ConcreteFunction function, Tensor[] inputs, bool allow_conversion) | |||
| { | |||
| // TODO(Rinne): revise it. | |||
| return true; | |||
| return function.CapturedInputs.Length + inputs.Length == function.Inputs.Length; | |||
| //var expected_inputs = function.func_graph.Inputs; | |||
| //foreach(var (arg, expected) in zip(inputs, expected_inputs)) | |||
| //{ | |||
| // if(arg.Id != expected.Id) | |||
| // { | |||
| // return false; | |||
| // } | |||
| //} | |||
| //return true; | |||
| } | |||
| } | |||
| public class RestoredFunction : Function | |||
| { | |||
| public RestoredFunction(Func<Tensors, Tensors> function, ConcreteFunction concrete_function, | |||
| string name, bool auto_graph = true): base(function, name, auto_graph) | |||
| IEnumerable<ConcreteFunction> _concrete_functions; | |||
| FunctionSpec _function_spec; | |||
| public RestoredFunction(Func<Tensor[], Tensor[]> function, string name, FunctionSpec function_spec, | |||
| IEnumerable<ConcreteFunction> concrete_functions): base(function, name, auto_graph: false) | |||
| { | |||
| _concrete_variable_creation_fn = concrete_function; | |||
| _concrete_functions = concrete_functions; | |||
| _function_spec = function_spec; | |||
| } | |||
| protected override bool _run_functions_eagerly() | |||
| @@ -102,6 +102,6 @@ public class SignatureMap: Trackable | |||
| return new Dictionary<string, Trackable>(); | |||
| } | |||
| return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); | |||
| return _signatures.Where(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); | |||
| } | |||
| } | |||
| @@ -132,8 +132,8 @@ namespace Tensorflow.Training | |||
| { | |||
| get | |||
| { | |||
| var trainable_extra_variables = _self_extra_variables.TakeWhile(x => x.Trainable).ToList(); | |||
| var non_trainable_extra_variables = _self_extra_variables.TakeWhile(x => !x.Trainable).ToList(); | |||
| var trainable_extra_variables = _self_extra_variables.Where(x => x.Trainable).ToList(); | |||
| var non_trainable_extra_variables = _self_extra_variables.Where(x => !x.Trainable).ToList(); | |||
| List<IVariableV1> non_trainable_variables = new(); | |||
| foreach(var obj in Values) | |||
| { | |||
| @@ -576,7 +576,7 @@ namespace Tensorflow.Training | |||
| if(save_type == SaveType.SAVEDMODEL) | |||
| { | |||
| children = children.Concat(this.TakeWhile(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair<string, Trackable>(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value); | |||
| children = children.Concat(this.Where(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair<string, Trackable>(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value); | |||
| } | |||
| return children; | |||
| @@ -0,0 +1,24 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Util | |||
| { | |||
| internal static class ProtoUtils | |||
| { | |||
| public static object GetSingleAttrValue(AttrValue value, AttrValue.ValueOneofCase valueCase) | |||
| { | |||
| return valueCase switch | |||
| { | |||
| AttrValue.ValueOneofCase.S => value.S, | |||
| AttrValue.ValueOneofCase.I => value.I, | |||
| AttrValue.ValueOneofCase.F => value.F, | |||
| AttrValue.ValueOneofCase.B => value.B, | |||
| AttrValue.ValueOneofCase.Type => value.Type, | |||
| AttrValue.ValueOneofCase.Shape => value.Shape, | |||
| AttrValue.ValueOneofCase.Tensor => value.Tensor, | |||
| AttrValue.ValueOneofCase.Func => value.Func, | |||
| _ => throw new NotImplementedException($"Unsupported attr value case: {valueCase}."), | |||
| }; | |||
| } | |||
| } | |||
| } | |||
| @@ -7,15 +7,15 @@ namespace Tensorflow.Util | |||
| { | |||
| internal static class function_utils | |||
| { | |||
| private static string _rewriter_config_optimizer_disabled; | |||
| public static string get_disabled_rewriter_config() | |||
| private static ByteString _rewriter_config_optimizer_disabled; | |||
| public static ByteString get_disabled_rewriter_config() | |||
| { | |||
| if(_rewriter_config_optimizer_disabled is null) | |||
| { | |||
| var config = new ConfigProto(); | |||
| var rewriter_config = config.GraphOptions.RewriteOptions; | |||
| rewriter_config.DisableMetaOptimizer = true; | |||
| _rewriter_config_optimizer_disabled = config.ToString(); | |||
| _rewriter_config_optimizer_disabled = config.ToByteString(); | |||
| } | |||
| return _rewriter_config_optimizer_disabled; | |||
| } | |||
| @@ -137,10 +137,12 @@ namespace Tensorflow.Util | |||
| switch (instance) | |||
| { | |||
| case Hashtable hash: | |||
| var result = new Hashtable(); | |||
| foreach ((object key, object value) in zip<object, object>(_sorted(hash), args)) | |||
| result[key] = value; | |||
| return result; | |||
| { | |||
| var result = new Hashtable(); | |||
| foreach ((object key, object value) in zip<object, object>(_sorted(hash), args)) | |||
| result[key] = value; | |||
| return result; | |||
| } | |||
| } | |||
| } | |||
| //else if( _is_namedtuple(instance) || _is_attrs(instance)) | |||
| @@ -221,6 +223,16 @@ namespace Tensorflow.Util | |||
| return list; | |||
| } | |||
| public static List<T> flatten<T>(IEnumerable<T> structure) | |||
| { | |||
| var list = new List<T>(); | |||
| foreach(var item in structure) | |||
| { | |||
| _flatten_recursive(item, list); | |||
| } | |||
| return list; | |||
| } | |||
| public static object[] flatten2(ICanBeFlattened structure) | |||
| => structure.Flatten(); | |||
| @@ -527,6 +539,14 @@ namespace Tensorflow.Util | |||
| return pack_sequence_as(structure, mapped_flat_structure) as T2; | |||
| } | |||
| public static IEnumerable<T2> map_structure<T1, T2>(Func<T1, T2> func, IEnumerable<T1> structure) where T2 : class | |||
| { | |||
| var flat_structure = flatten(structure); | |||
| var mapped_flat_structure = flat_structure.Select(func).Select(x => (object)x); | |||
| return pack_sequence_as(structure, mapped_flat_structure) as IEnumerable<T2>; | |||
| } | |||
| /// <summary> | |||
| /// Same as map_structure, but with only one structure (no combining of multiple structures) | |||
| /// </summary> | |||
| @@ -0,0 +1,33 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Framework; | |||
| namespace Tensorflow.Util | |||
| { | |||
| internal static class variable_utils | |||
| { | |||
| public static Tensor[] convert_variables_to_tensors(object[] values) | |||
| { | |||
| return values.Select(x => | |||
| { | |||
| if (resource_variable_ops.is_resource_variable(x)) | |||
| { | |||
| return ops.convert_to_tensor(x); | |||
| } | |||
| else if (x is CompositeTensor) | |||
| { | |||
| throw new NotImplementedException("The composite tensor has not been fully supported."); | |||
| } | |||
| else if(x is Tensor tensor) | |||
| { | |||
| return tensor; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("Currently the output of function to be traced must be `Tensor`."); | |||
| } | |||
| }).ToArray(); | |||
| } | |||
| } | |||
| } | |||
| @@ -248,7 +248,7 @@ namespace Tensorflow | |||
| foreach (var attr in node_def.Attr) | |||
| { | |||
| var bytes = attr.Value.ToByteArray(); | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status); | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: (ulong)bytes.Length, status: status); | |||
| status.Check(true); | |||
| } | |||
| @@ -575,10 +575,12 @@ namespace Tensorflow | |||
| public static HandleData get_resource_handle_data(Tensor graph_op) | |||
| { | |||
| throw new NotImplementedException(); | |||
| // This implementation has not been verified yet. | |||
| // If it throws an exception in the future, please revisit it. | |||
| var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | |||
| return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); | |||
| //var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | |||
| //return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); | |||
| } | |||
| public static void dismantle_graph(Graph graph) | |||
| @@ -35,6 +35,10 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| (x, y) = data_handler.DataAdapter.Expand1d(x, y); | |||
| using var tape = tf.GradientTape(); | |||
| //foreach (var variable in TrainableVariables) | |||
| //{ | |||
| // tape.watch(variable.Handle); | |||
| //} | |||
| var y_pred = Apply(x, training: true); | |||
| var loss = compiled_loss.Call(y, y_pred); | |||
| @@ -84,8 +84,8 @@ namespace Tensorflow.Keras.Layers | |||
| inputs.Insert(index, value); | |||
| } | |||
| var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); | |||
| var op = graph._create_op_from_tf_operation(c_op); | |||
| var (c_op, op_desc) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); | |||
| var op = graph._create_op_from_tf_operation(c_op, desc: op_desc); | |||
| op._control_flow_post_processing(); | |||
| // Record the gradient because custom-made ops don't go through the | |||
| @@ -51,9 +51,9 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| _all_functions = new HashSet<string>(objects_and_functions.Item2); | |||
| } | |||
| public IDictionary<string, Trackable> Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); | |||
| public IDictionary<string, Trackable> Functions => _function_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); | |||
| public IDictionary<string, Trackable> CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); | |||
| public IDictionary<string, Trackable> CheckpointableObjects => _object_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); | |||
| /// <summary> | |||
| /// Returns functions to attach to the root object during serialization. | |||
| @@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| get | |||
| { | |||
| var objects = CheckpointableObjects.TakeWhile( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value); | |||
| var objects = CheckpointableObjects.Where(x => _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value); | |||
| objects[Constants.KERAS_ATTR] = _keras_trackable; | |||
| return objects; | |||
| } | |||
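The `TakeWhile` → `Where` replacements above are behavioral fixes, not style changes: `Where` filters every element by the predicate, whereas `TakeWhile` stops at the first element that fails it and silently drops everything after that point. A small self-contained illustration of the difference:

using System;
using System.Linq;

var values = new[] { 1, 2, 0, 3 };
var filtered = values.Where(v => v > 0);        // 1, 2, 3  - keeps every match
var prefix   = values.TakeWhile(v => v > 0);    // 1, 2     - stops at the first 0
Console.WriteLine(string.Join(",", filtered));  // 1,2,3
Console.WriteLine(string.Join(",", prefix));    // 1,2

Applied to the dictionaries above, `TakeWhile` would have skipped every entry that happened to come after the first one failing the predicate.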
| @@ -0,0 +1,6 @@ | |||
| ¯root"_tf_keras_sequential*Š{"name": "sequential", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 784]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}, {"class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}}]}, "shared_object_id": 3, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 784]}, "ndim": 2, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 784]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 784]}, "float32", "input_1"]}, "keras_version": "2.11.0", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 784]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "shared_object_id": 0}, {"class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}, "shared_object_id": 1}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "shared_object_id": 2}]}}, "training_config": {"loss": "sparse_categorical_crossentropy", "metrics": [[{"class_name": "MeanMetricWrapper", "config": {"name": "accuracy", "dtype": "float32", "fn": "categorical_accuracy"}, "shared_object_id": 5}]], "weighted_metrics": null, "loss_weights": null, "optimizer_config": {"class_name": "Custom>Adam", "config": {"name": "Adam", "weight_decay": null, "clipnorm": null, "global_clipnorm": null, "clipvalue": null, "use_ema": false, "ema_momentum": 0.99, "ema_overwrite_frequency": null, "jit_compile": false, "is_legacy_optimizer": false, "learning_rate": 0.0010000000474974513, "beta_1": 0.9, "beta_2": 0.999, "epsilon": 1e-07, "amsgrad": false}}}}2 | |||
| ÿroot.layer_with_weights-0"_tf_keras_layer*È{"name": "transformer", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}, "shared_object_id": 1, "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}}2 | |||
| Þroot.layer-1"_tf_keras_layer*´{"name": "softmax", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "shared_object_id": 2, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10]}}2 | |||
| ¸9root.keras_api.metrics.0"_tf_keras_metric*�{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 6}2 | |||
| ë:root.keras_api.metrics.1"_tf_keras_metric*´{"class_name": "MeanMetricWrapper", "name": "accuracy", "dtype": "float32", "config": {"name": "accuracy", "dtype": "float32", "fn": "categorical_accuracy"}, "shared_object_id": 5}2 | |||
| @@ -6,7 +6,6 @@ using Tensorflow.Keras.Optimizers; | |||
| using Tensorflow.Keras.UnitTest.Helpers; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||
| @@ -62,11 +61,26 @@ public class SequentialModelLoad | |||
| [TestMethod] | |||
| public void Temp() | |||
| { | |||
| var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func"); | |||
| var model = tf.keras.models.load_model(@"Assets/python_func_model"); | |||
| model.summary(); | |||
| var x = tf.ones((2, 10)); | |||
| var x = tf.random.uniform((8, 784), -1, 1); | |||
| var y = model.Apply(x); | |||
| Console.WriteLine(y); | |||
| //model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||
| //var data_loader = new MnistModelLoader(); | |||
| //var num_epochs = 1; | |||
| //var batch_size = 8; | |||
| //var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||
| //{ | |||
| // TrainDir = "mnist", | |||
| // OneHot = false, | |||
| // ValidationSize = 58000, | |||
| //}).Result; | |||
| //model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||
| } | |||
| } | |||
| @@ -49,6 +49,22 @@ | |||
| <None Update="Assets\simple_model_from_auto_compile\bias0.npy"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Assets\python_func_model\fingerprint.pb"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Assets\python_func_model\keras_metadata.pb"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Assets\python_func_model\saved_model.pb"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Assets\python_func_model\variables\variables.data-00000-of-00001"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Assets\python_func_model\variables\variables.index"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -413,7 +413,7 @@ namespace Tensorflow.Native.UnitTest | |||
| ASSERT_EQ(TF_OK, s_.Code, s_.Message); | |||
| ASSERT_NE(func_, IntPtr.Zero); | |||
| ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_))); | |||
| c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_); | |||
| c_api.TF_GraphCopyFunction(host_graph_, func_, new SafeFuncGraphHandle(IntPtr.Zero), s_); | |||
| ASSERT_EQ(TF_OK, s_.Code, s_.Message); | |||
| } | |||
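Since the P/Invoke signature now types the gradient-function argument as a `SafeFuncGraphHandle`, a bare `IntPtr.Zero` no longer satisfies the parameter, so "no gradient function" is expressed as a handle wrapping the zero pointer. A small sketch of that pattern with an illustrative `SafeHandle` subclass (the names below are demo-only, not the binding's own types):

using System;
using System.Runtime.InteropServices;

var none = new DemoHandle(IntPtr.Zero);   // "null" expressed as a zero-wrapping handle
Console.WriteLine(none.IsInvalid);        // True

// Demo-only handle type, standing in for a binding-specific SafeHandle.
class DemoHandle : SafeHandle
{
    public DemoHandle(IntPtr ptr) : base(IntPtr.Zero, ownsHandle: true) => SetHandle(ptr);
    public override bool IsInvalid => handle == IntPtr.Zero;
    protected override bool ReleaseHandle() => true;   // nothing to free in this demo
}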