Browse Source

Partially support the backward pass of loaded function models.

tags/v0.100.5-BERT-load
Yaohui Liu 2 years ago
parent
commit
fd1eb40f25
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
65 changed files with 1886 additions and 255 deletions
  1. +31
    -0
      Tensorflow.Common/Extensions/DictionaryExtension.cs
  2. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.gradients.cs
  3. +7
    -0
      src/TensorFlowNET.Core/APIs/tf.tensor.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Attributes/c_api.ops.cs
  5. +1
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  6. +6
    -0
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
  8. +1
    -0
      src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
  9. +86
    -3
      src/TensorFlowNET.Core/Contexts/Context.Config.cs
  10. +40
    -4
      src/TensorFlowNET.Core/Contexts/Context.cs
  11. +3
    -2
      src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs
  12. +1
    -0
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
  13. +53
    -0
      src/TensorFlowNET.Core/Eager/backprop_util.cs
  14. +2
    -2
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  15. +0
    -6
      src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs
  16. +22
    -0
      src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs
  17. +12
    -3
      src/TensorFlowNET.Core/Framework/function_def_lib.cs
  18. +34
    -22
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  19. +92
    -25
      src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
  20. +2
    -2
      src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs
  21. +24
    -9
      src/TensorFlowNET.Core/Functions/Function.cs
  22. +37
    -23
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  23. +84
    -0
      src/TensorFlowNET.Core/Functions/TracingCompiler.cs
  24. +5
    -1
      src/TensorFlowNET.Core/Functions/c_api.function.cs
  25. +50
    -0
      src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs
  26. +10
    -5
      src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs
  27. +248
    -20
      src/TensorFlowNET.Core/Functions/monomorphic_function.cs
  28. +52
    -0
      src/TensorFlowNET.Core/Gradients/default_gradient.cs
  29. +91
    -4
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  30. +15
    -4
      src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
  31. +345
    -10
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  32. +8
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs
  33. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  34. +29
    -2
      src/TensorFlowNET.Core/Graphs/Graph.cs
  35. +37
    -0
      src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs
  36. +94
    -13
      src/TensorFlowNET.Core/Operations/Operation.cs
  37. +4
    -4
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  38. +1
    -1
      src/TensorFlowNET.Core/Operations/functional_ops.cs
  39. +45
    -0
      src/TensorFlowNET.Core/Operations/gen_functional_ops.cs
  40. +38
    -0
      src/TensorFlowNET.Core/Operations/gen_ops.cs
  41. +2
    -0
      src/TensorFlowNET.Core/Operations/handle_data_util.cs
  42. +46
    -25
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  43. +1
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  44. +6
    -4
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  45. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  46. +6
    -5
      src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs
  47. +68
    -27
      src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  48. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs
  49. +3
    -3
      src/TensorFlowNET.Core/Training/data_structures.cs
  50. +24
    -0
      src/TensorFlowNET.Core/Util/ProtoUtils.cs
  51. +3
    -3
      src/TensorFlowNET.Core/Util/function_utils.cs
  52. +24
    -4
      src/TensorFlowNET.Core/Util/nest.py.cs
  53. +33
    -0
      src/TensorFlowNET.Core/Util/variable_utils.cs
  54. +5
    -3
      src/TensorFlowNET.Core/ops.cs
  55. +4
    -0
      src/TensorFlowNET.Keras/Engine/Model.Train.cs
  56. +2
    -2
      src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs
  57. +3
    -3
      src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
  58. BIN
      test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/fingerprint.pb
  59. +6
    -0
      test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb
  60. BIN
      test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/saved_model.pb
  61. BIN
      test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.data-00000-of-00001
  62. BIN
      test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.index
  63. +17
    -3
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
  64. +16
    -0
      test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj
  65. +1
    -1
      test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs

+ 31
- 0
Tensorflow.Common/Extensions/DictionaryExtension.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;

namespace Tensorflow.Common.Extensions
{
/// <summary>
/// Python-style convenience helpers for <see cref="Dictionary{TKey, TValue}"/>
/// and <see cref="KeyValuePair{TKey, TValue}"/>.
/// </summary>
public static class DictionaryExtension
{
    /// <summary>
    /// Deconstructs a key-value pair into its key and value, enabling
    /// `var (k, v) = pair;` and `foreach (var (k, v) in dict)` syntax.
    /// </summary>
    public static void Deconstruct<T1, T2>(this KeyValuePair<T1, T2> pair, out T1 first, out T2 second)
    {
        first = pair.Key;
        second = pair.Value;
    }

    /// <summary>
    /// Copies every entry of <paramref name="other"/> into <paramref name="dic"/>,
    /// overwriting values for keys already present (mirrors Python's `dict.update`).
    /// </summary>
    public static void Update<T1, T2>(this Dictionary<T1, T2> dic, IDictionary<T1, T2> other)
    {
        foreach (var (key, value) in other)
        {
            dic[key] = value;
        }
    }

    /// <summary>
    /// Returns the value stored under <paramref name="key"/>, or
    /// <paramref name="defaultValue"/> when the key is absent (mirrors Python's `dict.get`).
    /// </summary>
    public static T2 GetOrDefault<T1, T2>(this Dictionary<T1, T2> dic, T1 key, T2 defaultValue)
    {
        // Single hash lookup via TryGetValue instead of ContainsKey + indexer (two lookups).
        return dic.TryGetValue(key, out var value) ? value : defaultValue;
    }
}
}

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.gradients.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow
{
public partial class tensorflow
{
GradientTape _tapeSet;
internal GradientTape _tapeSet;

/// <summary>
/// Record operations for automatic differentiation.


+ 7
- 0
src/TensorFlowNET.Core/APIs/tf.tensor.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Operations;

namespace Tensorflow
{
public partial class tensorflow
@@ -79,5 +81,10 @@ namespace Tensorflow
num_split: num_split,
axis: axis,
name: name);

public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
{
return gen_ops.ensure_shape(x, shape, name);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Attributes/c_api.ops.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value);

[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status);

/// <summary>
/// Set `num_dims` to -1 to represent "unknown rank".


+ 1
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -22,6 +22,7 @@ using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Tensorflow.Operations;

namespace Tensorflow
{


+ 6
- 0
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -107,6 +107,12 @@ namespace Tensorflow
}
}

public void Release()
{
_handle.Dispose();
_handle = null;
}

public override string ToString()
=> $"0x{_handle.DangerousGetHandle():x16}";



+ 1
- 1
src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs View File

@@ -161,7 +161,7 @@ public static class CheckPointUtils

internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
{
return full_list.TakeWhile(x =>
return full_list.Where(x =>
{
var saveables = x.gather_saveables_for_checkpoint();
return saveables is not null && saveables.Count > 0;


+ 1
- 0
src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs View File

@@ -109,6 +109,7 @@ namespace Tensorflow.Checkpoint
TrackableObjectGraph.Types.TrackableObject trackable_object = new();
trackable_object.SlotVariables.AddRange(td.slot_variable_proto);
trackable_object.Children.AddRange(td.children_proto);
object_graph_proto.Nodes.Add(trackable_object);
}
return object_graph_proto;
}


+ 86
- 3
src/TensorFlowNET.Core/Contexts/Context.Config.cs View File

@@ -14,9 +14,11 @@
limitations under the License.
******************************************************************************/

using Google.Protobuf;
using System;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Contexts
{
@@ -25,12 +27,93 @@ namespace Tensorflow.Contexts
/// </summary>
public sealed partial class Context
{
public ConfigProto Config { get; set; } = new ConfigProto
protected Device.PhysicalDevice[] _physical_devices;
protected Dictionary<Device.PhysicalDevice, int> _physical_device_to_index;
ConfigProto _config;
public ConfigProto Config
{
GpuOptions = new GPUOptions
get
{
_initialize_physical_devices();

var config = new ConfigProto();
if(_config is not null)
{
config.MergeFrom(_config);
}
config.LogDevicePlacement = _log_device_placement;

config.DeviceCount["CPU"] = 0;
config.DeviceCount["GPU"] = 0;
foreach(var dev in _physical_devices)
{
if (config.DeviceCount.ContainsKey(dev.DeviceType))
{
config.DeviceCount[dev.DeviceType] += 1;
}
else
{
config.DeviceCount[dev.DeviceType] = 1;
}
}

var gpu_options = _compute_gpu_options();
config.GpuOptions = GPUOptions.Parser.ParseFrom(gpu_options.ToByteArray());

return config;
}
set
{
_config = value;
}
}

protected void _initialize_physical_devices(bool reinitialize = false)
{
if(!reinitialize && _physical_devices is not null)
{
return;
}
var devs = list_physical_devices();
_physical_devices = devs.Select(d => new Device.PhysicalDevice()
{
DeviceName = d.DeviceName,
DeviceType = d.DeviceType
}).ToArray();
_physical_device_to_index = _physical_devices.Select((p, i) => new KeyValuePair<Device.PhysicalDevice, int>(p, i))
.ToDictionary(x => x.Key, x => x.Value);

_import_config();
}

protected void _import_config()
{
if(_config is null)
{
return;
}
if(!_config.DeviceCount.TryGetValue("CPU", out var num_cpus))
{
num_cpus = 1;
}
if(num_cpus != 1)
{
// TODO(Rinne): implement it.
}
};

var gpus = _physical_devices.Where(d => d.DeviceType == "GPU");
if(gpus.Count() == 0)
{
return;
}

if(!_config.DeviceCount.TryGetValue("GPU", out var gpu_count))
{
gpu_count = 0;
}

// TODO(Rinne): implement it.
}

ConfigProto MergeConfig()
{


+ 40
- 4
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -38,7 +38,26 @@ namespace Tensorflow.Contexts
public string ScopeName { get; set; } = "";
bool initialized = false;
ContextSwitchStack context_switches;
public FunctionCallOptions FunctionCallOptions { get; }
protected FunctionCallOptions _function_call_options;
public FunctionCallOptions FunctionCallOptions
{
get
{
if(_function_call_options is null)
{
var config = Config;
_function_call_options = new FunctionCallOptions()
{
Config = config
};
}
return _function_call_options;
}
set
{
_function_call_options = value;
}
}

SafeContextHandle _handle;

@@ -62,7 +81,6 @@ namespace Tensorflow.Contexts
if (initialized)
return;

Config = MergeConfig();
FunctionCallOptions.Config = Config;
var config_str = Config.ToByteArray();
var opts = new ContextOptions();
@@ -167,11 +185,29 @@ namespace Tensorflow.Contexts
return c_api.TFE_ContextHasFunction(_handle, name);
}

public void add_function(SafeFuncGraphHandle fn)
{
ensure_initialized();
Status status = new();
c_api.TFE_ContextAddFunction(_handle, fn, status);
status.Check(true);
}

public void remove_function(string name)
{
ensure_initialized();
Status status = new();
c_api.TFE_ContextRemoveFunction(_handle, name, status);
status.Check(true);
}

public void add_function_def(FunctionDef fdef)
{
ensure_initialized();
var fdef_string = fdef.ToString();
c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, fdef_string.Length);
var fdef_string = fdef.ToByteArray();
Status status = new Status();
c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, (ulong)fdef_string.Length, status);
status.Check(true);
}

public void restore_mode()


+ 3
- 2
src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs View File

@@ -9,10 +9,11 @@ namespace Tensorflow.Contexts
public class FunctionCallOptions
{
public ConfigProto Config { get; set; }
public string ExecutorType { get; set; }

public string config_proto_serialized()
public ByteString config_proto_serialized()
{
return Config.ToByteString().ToStringUtf8();
return Config.ToByteString();
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow.Eager


+ 53
- 0
src/TensorFlowNET.Core/Eager/backprop_util.cs View File

@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;

namespace Tensorflow.Eager
{
/// <summary>
/// Helpers for deciding whether gradients can flow through a tensor,
/// mirroring TensorFlow's Python `backprop_util`.
/// </summary>
internal static class backprop_util
{
    // TODO: add quantized_dtypes (after being supported).
    // Dtypes for which gradient recording is meaningful.
    private static HashSet<TF_DataType> _trainable_dtypes = new HashSet<TF_DataType>(new TF_DataType[]
    {
        dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128,
        dtypes.resource, dtypes.variant, TF_DataType.TF_BFLOAT16
    });

    /// <summary>
    /// True when gradients can be computed for <paramref name="tensor"/>;
    /// variant tensors are resolved to their element dtype first.
    /// </summary>
    public static bool IsTrainable(Tensor tensor)
    {
        var dtype = _DTypeFromTensor(tensor);
        return _trainable_dtypes.Contains(dtype);
    }

    /// <summary>
    /// True when <paramref name="dtype"/> is a trainable dtype.
    /// </summary>
    public static bool IsTrainable(TF_DataType dtype)
    {
        return _trainable_dtypes.Contains(dtype);
    }

    /// <summary>
    /// Returns the effective dtype of <paramref name="tensor"/>. For a variant
    /// tensor whose handle data records a single consistent element dtype,
    /// that element dtype is returned instead of TF_VARIANT.
    /// </summary>
    private static TF_DataType _DTypeFromTensor(Tensor tensor)
    {
        var dtype = tensor.dtype;
        if (dtype.as_base_dtype() == TF_DataType.TF_VARIANT)
        {
            CppShapeInferenceResult.Types.HandleData handle_data;
            if (tensor is EagerTensor)
            {
                handle_data = tensor.HandleData;
            }
            else
            {
                handle_data = handle_data_util.get_resource_handle_data(tensor);
            }
            if (handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null &&
                handle_data.ShapeAndType.Count > 0)
            {
                var first_type = handle_data.ShapeAndType[0].Dtype;
                if (first_type != DataType.DtInvalid)
                {
                    // Explicit scan instead of LINQ `.All(...)`: this file's usings do not
                    // include `System.Linq`, so the LINQ form would not compile here.
                    bool all_same = true;
                    foreach (var shape_and_type in handle_data.ShapeAndType)
                    {
                        if (shape_and_type.Dtype != first_type)
                        {
                            all_same = false;
                            break;
                        }
                    }
                    if (all_same)
                    {
                        return first_type.as_tf_dtype();
                    }
                }
            }
        }
        return dtype;
    }
}
}

+ 2
- 2
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -31,7 +31,7 @@ namespace Tensorflow
public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, string serialized_function_def, int size);
public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, byte[] serialized_function_def, ulong size, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy);
@@ -280,7 +280,7 @@ namespace Tensorflow
public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values);

[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrValueProto(SafeEagerOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status);
public static extern void TFE_OpSetAttrValueProto(IntPtr op, string attr_name, IntPtr proto, ulong proto_len, SafeStatusHandle status);

/// <summary>
///


+ 0
- 6
src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs View File

@@ -1,6 +0,0 @@
namespace Tensorflow.Framework.Models
{
class ScopedTFFunction
{
}
}

+ 22
- 0
src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Framework
{
/// <summary>
/// Thin wrapper pairing a native function-graph handle with the function's name.
/// </summary>
internal class ScopedTFFunction
{
    private readonly SafeFuncGraphHandle _func;
    private readonly string _funcName;

    /// <summary>
    /// Wraps <paramref name="func"/> together with its <paramref name="name"/>.
    /// </summary>
    public ScopedTFFunction(SafeFuncGraphHandle func, string name)
    {
        _func = func;
        _funcName = name;
    }

    /// <summary>
    /// Returns the underlying native handle.
    /// </summary>
    public SafeFuncGraphHandle Get() => _func;
}
}

+ 12
- 3
src/TensorFlowNET.Core/Framework/function_def_lib.cs View File

@@ -4,6 +4,7 @@ using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
using Tensorflow.Graphs;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;
using static Tensorflow.CppShapeInferenceResult.Types;

@@ -64,7 +65,7 @@ namespace Tensorflow.Framework
{
output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name;
}
// TODO(Rinne): func_graph._output_names = output_names
func_graph._output_names = output_names;

func_graph.Exit();
return func_graph;
@@ -154,9 +155,17 @@ namespace Tensorflow.Framework
foreach(var node_def in fdef.NodeDef)
{
var graph = default_graph;
// TODO(Rinne): The `Graph` lacks `_functions`, needed to be implemented in the future.
while(graph.OuterGraph is not null)
while (true)
{
if(graph is null)
{
break;
}
var f = graph.Functions.GetOrDefault(node_def.Op, null);
if(f is not null && graph.OuterGraph is null)
{
break;
}
graph = graph.OuterGraph;
}



+ 34
- 22
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -4,6 +4,7 @@ using System.Diagnostics;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Framework.Models;
using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Train;
using Tensorflow.Util;
@@ -19,7 +20,7 @@ namespace Tensorflow.Functions
protected IEnumerable<Tensor> _captured_inputs;
internal FuncGraph func_graph;
protected DelayedRewriteGradientFunctions _delayed_rewrite_functions;
protected Dictionary<string, string> _attrs;
protected Dictionary<string, AttrValue> _attrs;
protected FunctionSpec _function_spec;
protected FunctionSpec _pre_initialized_function_spec = null;
protected EagerDefinedFunction _inference_function;
@@ -29,22 +30,25 @@ namespace Tensorflow.Functions

public string Name => _delayed_rewrite_functions.Forward().Name;

public Tensor[] Outputs;
public Tensor[] Outputs => func_graph.Outputs;
public Type ReturnType;
public TensorSpec[] OutputStructure;
public IEnumerable<string> ArgKeywords { get; set; }
public long NumPositionArgs { get; set; }
public FunctionDef FunctionDef => _delayed_rewrite_functions.Forward().Definition;
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs;
public IEnumerable<IVariableV1> Variables => func_graph.Variables;
public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables;

public ConcreteFunction(string name)
{
func_graph = new FuncGraph(name);
_captured_inputs = func_graph.external_captures;
_attrs= new Dictionary<string, string>();
_attrs= new Dictionary<string, AttrValue>();
_set_infer_function();
}

public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null)
public ConcreteFunction(FuncGraph graph, Dictionary<string, AttrValue> attrs = null)
{
func_graph = graph;
_captured_inputs = func_graph.external_captures;
@@ -70,7 +74,7 @@ namespace Tensorflow.Functions
null);
func_graph.Exit();
_captured_inputs = func_graph.external_captures;
_attrs = new Dictionary<string, string>();
_attrs = new Dictionary<string, AttrValue>();
_set_infer_function();
}

@@ -93,7 +97,7 @@ namespace Tensorflow.Functions
null);
func_graph.Exit();
_captured_inputs = func_graph.external_captures;
_attrs = new Dictionary<string, string>();
_attrs = new Dictionary<string, AttrValue>();
_set_infer_function();
}

@@ -160,27 +164,20 @@ namespace Tensorflow.Functions
}
if (!executing_eagerly)
{
// TODO(Rinne): add the check
}
tensor_inputs.AddRange(captured_inputs);
tensor_inputs.AddRange(captured_inputs);

args = tensor_inputs.ToArray();

var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0;
var possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args);
// No tape is watching; skip to running the function.
if (possible_gradient_type == 0 && executing_eagerly)
if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE && executing_eagerly)
{
return _build_call_outputs(_inference_function.Call(args));
//var attrs = new object[]
//{
// "executor_type", "",
// "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
//};
//return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
}

if (forward_backward == null)
forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly);
forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly);
var (forward_function, args_with_tangents) = forward_backward.Forward();
Tensors flat_outputs = null;
if (executing_eagerly)
@@ -189,8 +186,12 @@ namespace Tensorflow.Functions
}
else
{
// TODO(Rinne): add `default_graph._override_gradient_function`.
flat_outputs = forward_function.Call(args_with_tangents);
tf_with(default_graph._override_gradient_function(new Dictionary<string, Func<Operation, object[], Tensor[]>>(){
{ "PartitionedCall", _get_gradient_function() }, { "StatefulPartitionedCall", _get_gradient_function() }
}), _ =>
{
flat_outputs = forward_function.Call(args_with_tangents);
});
}
forward_backward.Record(flat_outputs);
return _build_call_outputs(flat_outputs);
@@ -215,7 +216,8 @@ namespace Tensorflow.Functions
TangentInfo input_tangents;
if (executing_eagerly)
{
throw new NotImplementedException();
// TODO(Rinne): check if it needs to be implemented.
input_tangents = new TangentInfo();
}
else
{
@@ -239,7 +241,12 @@ namespace Tensorflow.Functions
}

// TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall.
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false);
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: tf.Runner.MustRecordGradient());
}

internal void set_variables(IEnumerable<IVariableV1> variables)
{
func_graph.Variables = variables;
}

internal void _set_infer_function()
@@ -274,6 +281,11 @@ namespace Tensorflow.Functions
};
}

internal Func<Operation, object[], Tensor[]> _get_gradient_function()
{
return _delayed_rewrite_functions._rewrite_forward_and_call_backward;
}

private Tensors _build_call_outputs(Tensors result)
{
// TODO(Rinne): deal with `func_graph.structured_outputs`


+ 92
- 25
src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs View File

@@ -9,18 +9,27 @@ using Tensorflow.Eager;
using Tensorflow.Graphs;
using Tensorflow.Operations;
using Tensorflow.Util;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;
using Tensorflow.Framework;
using System.Buffers;
using Tensorflow.Gradients;

namespace Tensorflow.Functions
{
public class EagerDefinedFunction
public class EagerDefinedFunction: IDisposable
{
public int _num_outputs;
FuncGraph _func_graph;
FuncGraph _graph;
FunctionDef _definition;
OpDef _signature;
string _name;
Tensor[] _func_graph_outputs;
internal ScopedTFFunction _c_func;
internal Tensor[] _func_graph_outputs;
internal string _grad_func_name;
internal Func<Operation, Tensor[], Tensor[]> csharp_grad_func;
internal EagerDefinedFunction _grad_func;
internal bool _registered_on_context = false;
public string Name => _name;
public DataType[] OutputTypes { get; protected set; }
public Shape[] OutputShapes { get; protected set; }
@@ -47,48 +56,93 @@ namespace Tensorflow.Functions
return _signature;
}
}
public EagerDefinedFunction(string name, FuncGraph graph,
public unsafe EagerDefinedFunction(string name, FuncGraph graph,
Tensors inputs, Tensors outputs,
Dictionary<string, string> attrs)
Dictionary<string, AttrValue> attrs)
{
var input_ops = inputs.Select(x => x.op).ToArray();
var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op))
.Select(x => x as Operation).ToArray();
var output_names = new string[0];
_func_graph = new FuncGraph(graph, name, attrs);
_func_graph_outputs = new List<Tensor>(outputs).ToArray();
_func_graph.ToGraph(operations, inputs, outputs, output_names);
var graph_output_names = graph._output_names;
string[] output_names;
if(graph_output_names is not null && outputs.All(t => graph_output_names.ContainsKey(ops.tensor_id(t))))
{
output_names = outputs.Select(t => graph_output_names[ops.tensor_id(t)]).ToArray();
if(output_names.Distinct().Count() != output_names.Length)
{
output_names = new string[0];
}
}
else
{
output_names = new string[0];
}

Status status = new Status();
var fn = c_api.TF_GraphToFunction(graph.c_graph,
name,
false,
operations.Length,
operations.Length == 0 ? new IntPtr[0] : operations.Select(x => (IntPtr)x).ToArray(),
inputs.Length,
inputs.Select(t => t._as_tf_output()).ToArray(),
outputs.Length,
outputs.Select(t => t._as_tf_output()).ToArray(),
output_names.Length != outputs.Length ? null : output_names,
IntPtr.Zero, // warning: the control outputs have been totally ignored.
null,
status);
status.Check(true);

_c_func = new ScopedTFFunction(fn, name);

foreach(var (attr_name, attr_value) in attrs)
{
var serialized = attr_value.ToByteArray();
c_api.TF_FunctionSetAttrValueProto(fn, attr_name, serialized, serialized.Length, status);
status.Check(true);
}

var signature = _get_definition().Signature;
_name = signature.Name;
// TODO(Rinne): deal with `fn`
tf_with(ops.init_scope(), s =>
{
tf.Context.add_function(fn);
_registered_on_context = true;
});

_num_outputs = signature.OutputArg.Count;
OutputTypes = signature.OutputArg.Select(x => x.Type).ToArray();
OutputShapes = outputs.Select(x => x.shape).ToArray();
_func_graph_outputs = new List<Tensor>(outputs).ToArray();
csharp_grad_func = null;
_graph = graph;
}

public Tensors Call(Tensors args)
public unsafe Tensors Call(Tensors args)
{
// TODO(Rinne): Add arg `CancellationManager`.
// TODO(Rinne): Check the arg length.
var function_call_options = tf.Context.FunctionCallOptions;
string config;
if (string.IsNullOrEmpty(function_call_options.config_proto_serialized()))
if (function_call_options.config_proto_serialized().Length == 0)
{
config = function_utils.get_disabled_rewriter_config();
config = function_utils.get_disabled_rewriter_config().ToString();
}
else
{
config = function_call_options.config_proto_serialized();
config = function_call_options.config_proto_serialized().ToString();
}
// TODO(Rinne): executor_type

config = ""; // TODO(Rinne): revise it.

string executor_type = function_call_options.ExecutorType ?? "";
var executing_eagerly = tf.Context.executing_eagerly();

var attrs = new object[]
{
"executor_type", "",
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
"executor_type", executor_type,
"config_proto", config
};

Tensor[] outputs;
@@ -103,9 +157,19 @@ namespace Tensorflow.Functions
}
else
{
tf.GradientTape().stop_recording();
outputs = functional_ops.partitioned_call(args, this, OutputTypes,
executing_eagerly, config, "");
if(tf.GetTapeSet().Count == 0)
{
outputs = functional_ops.partitioned_call(args, this, OutputTypes,
executing_eagerly, config, "");
}
else
{
var tape = tf.GetTapeSet().Peek();
tape.StopRecord();
outputs = functional_ops.partitioned_call(args, this, OutputTypes,
executing_eagerly, config, "");
tape.StartRecord();
}
}
foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs))
{
@@ -141,7 +205,7 @@ namespace Tensorflow.Functions
{
g.AddFunction(this);
}
foreach(var f in _func_graph.Functions.Values)
foreach(var f in _graph.Functions.Values)
{
if (!g.IsFunction(f.Name))
{
@@ -155,12 +219,15 @@ namespace Tensorflow.Functions
{
var buffer = c_api_util.tf_buffer();
Status status = new();
c_api.TF_FunctionToFunctionDef(_func_graph._func_graph_handle, buffer, status);
c_api.TF_FunctionToFunctionDef(_c_func.Get(), buffer, status);
status.Check(true);
var proto_data = c_api.TF_GetBuffer(buffer);
FunctionDef function_def = new();
function_def.MergeFrom(proto_data.AsSpan<byte>());
return function_def;
return FunctionDef.Parser.ParseFrom(proto_data.AsSpan<byte>());
}

public void Dispose()
{
tf.Context.remove_function(Name);
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs View File

@@ -17,9 +17,9 @@ namespace Tensorflow.Functions
public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
{
var outputs = _func_graph.Outputs;
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs)
(_forward_function, _forward_graph, _backward_function, _forwardprop_output_indices, _num_forwardprop_outputs)
= BuildFunctionsForOutputs(outputs, inference_args);
return _forward;
return _forward_function;
}
}
}

+ 24
- 9
src/TensorFlowNET.Core/Functions/Function.cs View File

@@ -10,23 +10,26 @@ namespace Tensorflow
private IntPtr _handle;
#pragma warning restore CS0169 // The field 'Function._handle' is never used

protected Func<Tensors, Tensors> _function;
protected Func<Tensor[], Tensor[]> _csharp_function;
protected ConcreteFunction _concrete_variable_creation_fn;
protected bool _auto_graph;
protected bool _autograph;
protected TracingCompiler _variable_creation_fn;
protected bool _has_initialized;
public string Name { get; set; }
public Function(Func<Tensors, Tensors> function,
public Function(Func<Tensor[], Tensor[]> csharp_function,
string name, bool auto_graph = true)
{
_function = function;
_csharp_function = csharp_function;
Name = name;
_auto_graph = auto_graph;
_autograph = auto_graph;
_has_initialized = false;
}

public virtual Tensors Apply(Tensors inputs)
{
if (_run_functions_eagerly())
{
return _function(inputs);
return _csharp_function(inputs);
}

var result = _call(inputs);
@@ -35,20 +38,32 @@ namespace Tensorflow

protected virtual Tensors _call(Tensors inputs)
{
_initialize();
if (!_has_initialized)
{
_initialize(inputs);
}

return _concrete_variable_creation_fn.CallFlat(inputs,
_concrete_variable_creation_fn.CapturedInputs);
}

protected TracingCompiler _compiler(Func<Tensor[], Tensor[]> fn)
{
var name = nameof(fn);
return new TracingCompiler(fn, name, autograph: _autograph);
}

protected virtual bool _run_functions_eagerly()
{
return false;
}

private void _initialize()
private void _initialize(Tensor[] args)
{

_variable_creation_fn = _compiler(_csharp_function);
_variable_creation_fn._name = this.Name;
_concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args);
_has_initialized = true;
}
}
}

+ 37
- 23
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -3,8 +3,10 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.NumPy;
using Tensorflow.Operations;
using static Tensorflow.Binding;
using static Tensorflow.tensorflow;

@@ -22,11 +24,11 @@ namespace Tensorflow.Functions
protected string _INFERENCE_PREFIX = "__inference_";

protected FuncGraph _func_graph;
protected EagerDefinedFunction _forward;
protected EagerDefinedFunction _forward_function;
protected FuncGraph _forward_graph;
protected List<int> _forwardprop_output_indices;
protected int _num_forwardprop_outputs;
protected ConcreteFunction _backward;
protected ConcreteFunction _backward_function;
BackwardFunction _backward_function_wrapper;

public TapeGradientFunctions(FuncGraph func_graph,
@@ -49,8 +51,8 @@ namespace Tensorflow.Functions
public virtual void Record(Tensors flat_outputs, Tensors inference_args)
{
// TODO(Rinne): add arg `input_tagents`.
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record,
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward_function, flat_outputs);
tf.Runner.RecordGradient(_forward_function.Name, inference_args, new object[0], to_record,
getBackwardFunction: backward_function);
}

@@ -134,46 +136,58 @@ namespace Tensorflow.Functions
var trainable_indices = new List<int>();
foreach(var (index, output) in enumerate(outputs))
{
if (gradients_util.IsTrainable(output))
if (backprop_util.IsTrainable(output))
{
trainable_outputs.Add(output);
trainable_indices.Add(index);
}
}

var gradients_wrt_outputs = new List<Tensor>();
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}");
var backwards_graph = new FuncGraph(_func_graph.Name);
backwards_graph.as_default();
var gradients_wrt_outputs = new List<Tensor>();
foreach (var output in trainable_outputs)
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
{
var (gradient_shape, gradient_dtype) = default_gradient.shape_and_dtype(output);
var gradient_placeholder = tf.placeholder(gradient_dtype, gradient_shape);
gradients_wrt_outputs.Add(gradient_placeholder);
handle_data_util.copy_handle_data(output, gradient_placeholder);
}
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
_func_graph.Inputs,
grad_ys: gradients_wrt_outputs.ToArray(),
src_graph: _func_graph);
_func_graph.Inputs,
grad_ys: gradients_wrt_outputs.ToArray(),
src_graph: _func_graph);

var captures_from_forward = backwards_graph.external_captures
.Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph)
.ToArray();
HashSet<Tensor> existing_outputs = new(_func_graph.Outputs);
foreach(var capture in captures_from_forward)
{
if (!_func_graph.Outputs.Contains(capture))
if (!existing_outputs.Contains(capture))
{
existing_outputs.Add(capture);
_func_graph.Outputs.Add(capture);
}
}
backwards_graph.Exit();

var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
var backward_function_attr = new Dictionary<string, string>();
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
gradients_wrt_outputs.append(backwards_graph.internal_captures);
backwards_graph.Inputs = gradients_wrt_outputs;
backwards_graph.Outputs = gradients_wrt_inputs;
backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray();
backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null));

var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph);
//var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
//var backward_function_attr = new Dictionary<string, string>();
//backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;

var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);
//var backward_function = new ConcreteFunction(backwards_graph,
// monomorphic_function_utils._parse_func_attrs(backward_function_attr));
var forward_function_attr = new Dictionary<string, string>();
forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name;
var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,
_func_graph.Inputs, _func_graph.Outputs, forward_function_attr);
//var forward_function_attr = new Dictionary<string, string>();
//forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name;
//var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,
// _func_graph.Inputs, _func_graph.Outputs,
// monomorphic_function_utils._parse_func_attrs(forward_function_attr));
return (forward_function, _func_graph, backward_function, null, 0);
}


+ 84
- 0
src/TensorFlowNET.Core/Functions/TracingCompiler.cs View File

@@ -0,0 +1,84 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using Tensorflow.Graphs;

namespace Tensorflow.Functions
{
/// <summary>
/// Caches concrete functions produced by tracing a C# function, keyed by the
/// identity of the input tensors. Mirrors the role of TensorFlow Python's
/// `TracingCompiler`.
/// </summary>
public class TracingCompiler
{
    Func<Tensor[], Tensor[]> _csharp_function;
    //FunctionSpec _function_spec;
    internal string _name;
    bool _autograph;
    // Cache of traced functions, keyed by the identity string of the inputs.
    Dictionary<string, ConcreteFunction> _function_cache;
    Dictionary<string, AttrValue> _function_attributes;
    // Number of cache misses, i.e. how many times tracing actually ran.
    int _tracing_count;

    public TracingCompiler(Func<Tensor[], Tensor[]> csharp_function, string name, object? input_signatures = null,
        Dictionary<string, AttrValue> attributes = null, bool autograph = true, object? autograph_options = null,
        bool reduce_retracing = false, bool capture_by_value = false)
    {
        // The former `pure_function` local (derived from IMPLEMENTS_ATTRIBUTE_NAME)
        // was never read and has been removed.
        _csharp_function = csharp_function;
        _name = name;
        _autograph = autograph;
        _function_attributes = attributes ?? new Dictionary<string, AttrValue>();
        _function_cache = new Dictionary<string, ConcreteFunction>();
        _tracing_count = 0;
    }

    /// <summary>
    /// Calls the function, tracing it first if no concrete function is
    /// cached for these inputs.
    /// </summary>
    public Tensor[] Apply(Tensor[] inputs)
    {
        // TODO(Rinne): add lock here.
        var (concrete_function, filtered_flat_args) = _maybe_define_function(inputs);
        return concrete_function.CallFlat(filtered_flat_args, concrete_function.CapturedInputs);
    }

    internal ConcreteFunction _get_concrete_function_internal_garbage_collected(Tensor[] args)
    {
        var (concrete_function, _) = _maybe_define_concrete_function(args);
        return concrete_function;
    }

    private (ConcreteFunction, Tensor[]) _maybe_define_concrete_function(Tensor[] args)
    {
        return _maybe_define_function(args);
    }

    // Looks up the cache by input identity; traces and caches on a miss.
    private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args)
    {
        var lookup_func_key = make_cache_key(args);
        if (!_function_cache.TryGetValue(lookup_func_key, out var concrete_function))
        {
            concrete_function = _create_concrete_function(args);
            _function_cache[lookup_func_key] = concrete_function;
        }
        return (concrete_function, args);
    }

    // Traces `_csharp_function` into a fresh FuncGraph and wraps the result in
    // a ConcreteFunction carrying `_function_attributes`.
    private ConcreteFunction _create_concrete_function(Tensor[] args)
    {
        _tracing_count++;
        var concrete_function = new ConcreteFunction(FuncGraph.func_graph_from_func(
            _name, x => _csharp_function(x.Where(y => y is Tensor).Select(y => (Tensor)y).ToArray()),
            args, new Dictionary<string, object>(), autograph: _autograph
        ), _function_attributes);
        return concrete_function;
    }

    // Cache key is the concatenated identity (name + id) of each input tensor.
    // (Renamed from the typo `male_cache_key`.)
    // NOTE(review): keying on tensor identity means distinct tensors with the
    // same shape/dtype retrace — confirm this matches the intended policy.
    private static string make_cache_key(Tensor[] inputs)
    {
        var key = new StringBuilder();
        foreach (var input in inputs)
        {
            key.Append($"{input.name}_{input.Id}");
        }
        return key.ToString();
    }
}
}

+ 5
- 1
src/TensorFlowNET.Core/Functions/c_api.function.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Runtime.InteropServices;
using Tensorflow.Functions;

namespace Tensorflow
{
@@ -54,6 +55,9 @@ namespace Tensorflow
public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func);

[DllImport(TensorFlowLibName)]
public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status);
public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, SafeFuncGraphHandle grad, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern int TF_GraphGetFunctions(SafeGraphHandle g, IntPtr[] funcs, int max_func, SafeStatusHandle status);
}
}

+ 50
- 0
src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs View File

@@ -0,0 +1,50 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework;
using Tensorflow.Framework.Models;
using Tensorflow.Util;

namespace Tensorflow.Functions
{
/// <summary>
/// Helpers for flattening possibly-nested inputs while treating resource
/// variables as leaf values. Composite tensors other than resource variables
/// are not yet supported.
/// </summary>
internal static class composite_tensor_utils
{
    /// <summary>
    /// Flattens <paramref name="inputs"/> into a list of leaf objects.
    /// </summary>
    /// <exception cref="NotImplementedException">A leaf is a composite tensor
    /// that is not a resource variable.</exception>
    public static List<object> flatten_with_variables(object inputs)
    {
        List<object> flat_inputs = new();
        foreach (var value in nest.flatten(inputs))
        {
            if (value is CompositeTensor && !resource_variable_ops.is_resource_variable(value))
            {
                throw new NotImplementedException("The composite tensor has not been fully supported.");
            }
            else
            {
                flat_inputs.Add(value);
            }
        }
        return flat_inputs;
    }

    /// <summary>
    /// Like <see cref="flatten_with_variables"/>, but additionally rejects
    /// <c>TypeSpec</c> leaves that are not <c>TensorSpec</c>s.
    /// </summary>
    public static List<object> flatten_with_variables_or_variable_specs(object arg)
    {
        List<object> flat_inputs = new();
        foreach (var value in nest.flatten(arg))
        {
            if (value is CompositeTensor && !resource_variable_ops.is_resource_variable(value))
            {
                throw new NotImplementedException("The composite tensor has not been fully supported.");
            }
            // TODO(Rinne): deal with `VariableSpec`.
            // (The unused pattern capture `type_spec` was removed; only the type test matters.)
            else if (value is TypeSpec && value is not TensorSpec)
            {
                throw new NotImplementedException("The TypeSpec has not been fully supported.");
            }
            else
            {
                flat_inputs.Add(value);
            }
        }
        return flat_inputs;
    }
}
}

+ 10
- 5
src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs View File

@@ -34,11 +34,10 @@ namespace Tensorflow.Functions
"https://github.com/SciSharp/TensorFlow.NET/issues");
}
});
var bound_variables = inputs.TakeWhile(obj => obj is IVariableV1);
var bound_variables = inputs.Where(obj => obj is IVariableV1).Select(x => (IVariableV1)x);

List<Tensor> captured_inputs_list = new();
// TODO(Rinne): concrete_function.set_variables(bound_variables)

concrete_function.set_variables(bound_variables);

if (bound_inputs is not null)
{
@@ -54,8 +53,14 @@ namespace Tensorflow.Functions
concrete_function.func_graph.replace_capture(bound_input, internal_capture);
if(internal_capture.dtype == dtypes.resource)
{
// skip the check of variable.
handle_data_util.copy_handle_data(bound_input, internal_capture);
if (resource_variable_ops.is_resource_variable(bound_input))
{
handle_data_util.copy_handle_data(bound_input.Handle, internal_capture);
}
else
{
handle_data_util.copy_handle_data(bound_input, internal_capture);
}
}
concrete_function.func_graph.capture(bound_input);
}


+ 248
- 20
src/TensorFlowNET.Core/Functions/monomorphic_function.cs View File

@@ -1,20 +1,137 @@
using System;
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;
using Tensorflow.Framework.Models;
using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Common.Extensions;
using Tensorflow.Operations;
using Tensorflow.Framework;
using static Tensorflow.Binding;
using System.Diagnostics;

namespace Tensorflow.Functions
{
public class DelayedRewriteGradientFunctions: TapeGradientFunctions
// Naming and attribute helpers shared by the forward/backward function
// machinery (ported from TensorFlow Python's monomorphic_function).
internal static class monomorphic_function_utils
{
    // Prefixes used to derive unique names for the traced function variants.
    internal static string _FORWARD_PREFIX = "__forward_";
    internal static string _BACKWARD_PREFIX = "__backward_";
    internal static string _INFERENCE_PREFIX = "__inference_";
    // Attribute keys used to link a forward function with its backward pair.
    internal static string IMPLEMENTS_ATTRIBUTE_NAME = "_implements";
    internal static string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
    internal static string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";

    // Each of the three name helpers appends a process-unique id (ops.uid())
    // so repeated tracings of the same graph never collide.
    public static string _inference_name(string name)
    {
        return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}";
    }
    public static string _forward_name(string name)
    {
        return $"{_FORWARD_PREFIX}{name}_{ops.uid()}";
    }
    public static string _backward_name(string name)
    {
        return $"{_BACKWARD_PREFIX}{name}_{ops.uid()}";
    }

    // Builds the (forward, backward) function pair for a traced graph:
    // the backward ConcreteFunction records the forward function's name in its
    // attrs, and vice versa, so each can find the other at gradient time.
    public static (EagerDefinedFunction, ConcreteFunction) _create_forward_backward_with_graph(Dictionary<string, AttrValue> attrs,
        FuncGraph forward_graph, FuncGraph backwards_graph)
    {
        string forward_function_name = _forward_name(forward_graph.Name);
        Dictionary<string, AttrValue> common_attributes;
        if(attrs is null)
        {
            common_attributes = new Dictionary<string, AttrValue>();
        }
        else
        {
            // Copy so the caller's dictionary is not mutated below.
            common_attributes = new Dictionary<string, AttrValue>(attrs);
        }

        // The "_implements" attribute only belongs on the inference function.
        if (common_attributes.ContainsKey(IMPLEMENTS_ATTRIBUTE_NAME))
        {
            common_attributes.Remove(IMPLEMENTS_ATTRIBUTE_NAME);
        }
        var backward_function_attr = _parse_func_attrs(new Dictionary<string, object>()
        {
            {FORWARD_FUNCTION_ATTRIBUTE_NAME, forward_function_name }
        });
        backward_function_attr.Update(common_attributes);
        var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);
        var forward_function_attr = _parse_func_attrs(new Dictionary<string, object>()
        {
            {BACKWARD_FUNCTION_ATTRIBUTE_NAME, backward_function.Name }
        });
        forward_function_attr.Update(common_attributes);
        var forward_function = new EagerDefinedFunction(forward_function_name, forward_graph,
            forward_graph.Inputs, forward_graph.Outputs, forward_function_attr);
        return (forward_function, backward_function);
    }

    // Converts a loosely-typed attribute dictionary into protobuf AttrValues.
    // Throws ValueError for unsupported value types.
    public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, object> attributes)
    {
        Dictionary<string, AttrValue> attrs = new();
        foreach(var item in attributes)
        {
            var key = item.Key;
            var value = item.Value;
            if (value is AttrValue attr_value)
            {
                attrs[key] = attr_value;
            }
            else if (value is bool b)
            {
                attrs[key] = new AttrValue() { B = b };
            }
            else if (value is int i)
            {
                attrs[key] = new AttrValue() { I = i };
            }
            else if (value is float f)
            {
                attrs[key] = new AttrValue() { F = f };
            }
            else if(value is string s)
            {
                attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(s) };
            }
            else if (value is byte[] bytes)
            {
                attrs[key] = new AttrValue() { S = ByteString.CopyFrom(bytes) };
            }
            else
            {
                throw new ValueError($"Attribute {key} must be bool, int, float, string, or " +
                    $"AttrValue. Got {value.GetType()}.");
            }
        }
        return attrs;
    }

    // String-only overload: every value becomes a UTF-8 string AttrValue.
    public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, string> attributes)
    {
        Dictionary<string, AttrValue> attrs = new();
        foreach (var item in attributes)
        {
            var key = item.Key;
            var value = item.Value;
            attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(value) };
        }
        return attrs;
    }
}
public class DelayedRewriteGradientFunctions : TapeGradientFunctions
{
EagerDefinedFunction _inference_function;
Dictionary<string, string> _attrs;
Dictionary<string, AttrValue> _attrs;
int _num_inference_outputs;
public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, string> attrs)
:base(func_graph, false)
Dictionary<int, (EagerDefinedFunction, ConcreteFunction)> _cached_function_pairs = new();
public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, AttrValue> attrs)
: base(func_graph, false)
{
_func_graph= func_graph;
_inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name),
_func_graph = func_graph;
_inference_function = new EagerDefinedFunction(monomorphic_function_utils._inference_name(_func_graph.Name),
_func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs);
_attrs = attrs;
_num_inference_outputs = _func_graph.Outputs.Length;
@@ -22,7 +139,7 @@ namespace Tensorflow.Functions

public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null)
{
if(input_tangents is not null)
if (input_tangents is not null)
{
throw new InvalidArgumentError($"unexpectedly got forwardprop information in " +
$"a class that does not support forwardprop.");
@@ -32,23 +149,134 @@ namespace Tensorflow.Functions

public override void Record(Tensors flat_outputs, Tensors inference_args)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
base.Record(flat_outputs, inference_args);
var (backward_function, to_record) = _backward(flat_outputs);
foreach(var tape in tf.GetTapeSet())
{
tape.RecordOperation(_inference_function.Signature.Name, to_record,
inference_args.Select(t => new TapeTensor(t)).ToArray(), backward_function);
}
}

//private (BackwardFunction, Tensors) _backward(Tensors outputs)
//{
// Tensor[] backward_function(Tensor[] grads, long[] unneeded_gradients)
// {
// var call_op = outputs[0].op;
public (EagerDefinedFunction, ConcreteFunction) forward_backward(int num_doutputs = -2)
{
if(num_doutputs == -2)
{
num_doutputs = _num_inference_outputs;
}
if(_cached_function_pairs.TryGetValue(num_doutputs, out var target))
{
return target;
}
var (forward, backward) = _construct_forward_backward(num_doutputs);
_cached_function_pairs[num_doutputs] = (forward, backward);
return (forward, backward);

// }
//}
}

private string _inference_name(string name)
private (BackwardFunction, Tensors) _backward(Tensors outputs)
{
return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}";
Tensor[] backward_function(Tensor[] args, long[] unneeded_gradients)
{
var call_op = outputs[0].op;
return _rewrite_forward_and_call_backward(call_op, args);
}
return (backward_function, outputs);
}

// Rewrites the recorded call op `op` to use the full forward function (which
// exposes the intermediate tensors the backward pass needs), then calls the
// backward function with the cleaned incoming gradients `doutputs`.
internal Tensor[] _rewrite_forward_and_call_backward(Operation op, params object[] doutputs)
{
    var (forward_function, backward_function) = forward_backward(doutputs.Length);
    // Nothing to differentiate: return the backward function's (empty)
    // structured outputs directly.
    if(backward_function.Outputs is null || backward_function.Outputs.Length == 0)
    {
        return backward_function.FlatStructuredOutputs;
    }
    forward_function.AddToGraph(op.graph);

    // Point the call op at the forward function and extend its output list
    // with the extra intermediates the forward variant exposes.
    op._set_func_attr("f", forward_function.Name);
    op._set_type_list_attr("Tout", forward_function.OutputTypes);
    op._add_outputs(forward_function.OutputTypes.Select(x => x.as_tf_dtype()).
        Skip(op.outputs.Length).ToArray(), forward_function.OutputShapes.Skip(op.outputs.Length).ToArray()
    );
    // Propagate resource handle data from the traced graph outputs onto the
    // (pre-existing) op outputs.
    for(int i = 0; i < op.outputs.Length; i++)
    {
        var func_graph_output = forward_function._func_graph_outputs[i];
        handle_data_util.copy_handle_data(func_graph_output, op.outputs[i]);
    }

    // Map graph-internal output tensors (by id) to the op's real outputs so
    // the backward function's captures resolve to live tensors.
    var capture_mapping = zip(_func_graph.Outputs.Select(t => ops.tensor_id(t)), op.outputs).
        ToDictionary(x => x.Item1, x => x.Item2);
    var remapped_captures = backward_function.CapturedInputs.Select(
        x => capture_mapping.GetOrDefault(ops.tensor_id(x), x)
    );

    // Normalize incoming gradients: drop non-trainable slots, substitute
    // zeros for missing gradients, and densify IndexedSlices.
    List<Tensor> cleaned_doutputs = new();
    foreach(var (doutput, placeholder) in zip(doutputs, _func_graph.Outputs))
    {
        if (backprop_util.IsTrainable(placeholder))
        {
            if(doutput is IndexedSlices)
            {
                cleaned_doutputs.Add(ops.convert_to_tensor(doutput));
            }
            else if(doutput is null)
            {
                cleaned_doutputs.Add(default_gradient.zeros_like(placeholder));
            }
            else if(doutput is Tensor tensor)
            {
                cleaned_doutputs.Add(tensor);
            }
            else
            {
                throw new ValueError($"Unsupported type {doutput.GetType()} in function _rewrite_forward_and_call_backward");
            }
        }
    }

    return backward_function.CallFlat(cleaned_doutputs.ToArray(), remapped_captures.ToArray());
}

// Traces a backward graph for the first `num_doutputs` outputs of the forward
// graph and returns the paired (forward, backward) functions. Any forward
// intermediates captured by the backward graph are promoted to forward
// outputs so the backward pass can reach them.
private (EagerDefinedFunction, ConcreteFunction) _construct_forward_backward(int num_doutputs)
{
    // NOTE(review): `trainable_outputs` is a deferred LINQ query enumerated
    // several times below (signature loop, _backprop_function); this is
    // harmless only if _func_graph.Outputs is stable during tracing — confirm.
    var trainable_outputs = _func_graph.Outputs.Take(num_doutputs).Where(x => backprop_util.IsTrainable(x));

    // One gradient placeholder spec per trainable output, with resource
    // tensors resolved to their underlying shape/dtype.
    List<TensorSpec> signature = new();
    foreach(var t in trainable_outputs)
    {
        var (shape, dtype) = default_gradient.shape_and_dtype(t);
        signature.Add(new TensorSpec(shape, dtype));
    }

    // Given upstream gradients grad_ys, compute gradients w.r.t. the forward
    // graph's inputs.
    Tensor[] _backprop_function(Tensor[] grad_ys)
    {
        return gradients_util._GradientsHelper(trainable_outputs.ToArray(), _func_graph.Inputs,
            grad_ys, src_graph: _func_graph);
    }

    // Trace the backward graph while the forward graph is the default, so
    // references to forward tensors become captures.
    _func_graph.as_default();
    FuncGraph backwards_graph = new(monomorphic_function_utils._backward_name(_func_graph.Name));
    FuncGraph.func_graph_from_func(backwards_graph.Name, x => _backprop_function(x.Select(y =>
    {
        Debug.Assert(y is Tensor);
        return (Tensor)y;
    }).ToArray()), new object[0], new Dictionary<string, object>(), signature.ToArray(), backwards_graph);
    var backwards_graph_captures = backwards_graph.external_captures;
    // Forward-graph tensors captured by the backward graph (excluding eager
    // constants) must be exposed as extra forward outputs.
    var captures_from_forward = backwards_graph_captures.Where(c => c is not EagerTensor && c.graph == _func_graph);
    HashSet<Tensor> existing_outputs = new HashSet<Tensor>(_func_graph.Outputs);
    foreach(var capture in captures_from_forward)
    {
        if (!existing_outputs.Contains(capture))
        {
            existing_outputs.Add(capture);
            _func_graph.Outputs.Add(capture);
        }
    }

    var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(
        _attrs, _func_graph, backwards_graph);
    _func_graph.Exit();
    return (forward_function, backward_function);
}
}
}

+ 52
- 0
src/TensorFlowNET.Core/Gradients/default_gradient.cs View File

@@ -0,0 +1,52 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Gradients
{
/// <summary>
/// Utilities for building default (zero) gradients, resolving the shape and
/// dtype of resource tensors through their handle data.
/// </summary>
internal static class default_gradient
{
    /// <summary>
    /// Returns the shape and dtype to use for a gradient of <paramref name="t"/>.
    /// For resource tensors these come from the tensor's single handle datum;
    /// otherwise they are the tensor's own shape and dtype.
    /// </summary>
    /// <exception cref="ValueError">The resource tensor carries no usable handle data.</exception>
    public static (Shape, TF_DataType) shape_and_dtype(Tensor t)
    {
        if (t.dtype == dtypes.resource)
        {
            var handle_data = resource_variable_ops.get_eager_safe_handle_data(t);
            if (handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1)
            {
                throw new ValueError($"Internal error: Tried to take gradients (or similar) " +
                    $"of a variable without handle data:\n{t}");
            }
            return (new Shape(handle_data.ShapeAndType[0].Shape), handle_data.ShapeAndType[0].Dtype.as_tf_dtype());
        }
        return (t.shape, t.dtype);
    }

    /// <summary>
    /// Returns a zero tensor matching <paramref name="t"/>'s gradient shape/dtype.
    /// </summary>
    public static Tensor zeros_like(Tensor t)
    {
        if (t.dtype == dtypes.resource)
        {
            var (shape, dtype) = shape_and_dtype(t);
            return array_ops.zeros(shape, dtype);
        }
        else
        {
            return array_ops.zeros_like(t);
        }
    }

    /// <summary>
    /// Returns the dtype to use for zeros of <paramref name="t"/>'s gradient.
    /// </summary>
    /// <exception cref="ValueError">The resource tensor carries no usable handle data.</exception>
    public static TF_DataType get_zeros_dtype(Tensor t)
    {
        if (t.dtype == dtypes.resource)
        {
            // Route through shape_and_dtype so the handle-data validation
            // (and its error message) lives in exactly one place.
            var (_, dtype) = shape_and_dtype(t);
            return dtype;
        }
        return t.dtype;
    }
}
}

+ 91
- 4
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -14,10 +14,15 @@
limitations under the License.
******************************************************************************/

using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Operations;
using Tensorflow.Operations.ControlFlows;
using static Tensorflow.Binding;

@@ -148,7 +153,7 @@ namespace Tensorflow
Tensor[] in_grads = null;
Func<Operation, Tensor[], Tensor[]> grad_fn = null;
var is_partitioned_call = _IsPartitionedCall(op);
var is_func_call = false;
var is_func_call = src_graph.IsFunction(op.type) || is_partitioned_call;
var has_out_grads = out_grads.Exists(x => x != null);
if (has_out_grads && !stop_ops.Contains(op))
{
@@ -162,14 +167,41 @@ namespace Tensorflow
{
if (is_func_call)
{
EagerDefinedFunction func_call = null;
if (is_partitioned_call)
{

var func_attr = op.get_attr("f");
Debug.Assert(func_attr is NameAttrList);
var func_name = ((NameAttrList)func_attr).Name;
func_call = src_graph._get_function(func_name);
if(func_call is null && src_graph.OuterGraph is not null)
{
var graph = src_graph.OuterGraph;
while(graph is not null)
{
func_call = graph._get_function(func_name);
if(func_call is not null)
{
break;
}
if(graph.OuterGraph is not null)
{
graph = graph.OuterGraph;
}
else
{
break;
}
}
}
}
else
{

func_call = src_graph._get_function(op.type);
}
// skip the following codes:
// `func_call = getattr(op, "__defun", func_call)`
grad_fn = func_call.csharp_grad_func;
}
else
{
@@ -213,6 +245,8 @@ namespace Tensorflow
}
else
{
in_grads = _MaybeCompile(grad_scope, op, out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
null, (x, y) => _SymGrad(x, y));
throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
}
_VerifyGeneratedGradients(in_grads, op);
@@ -668,6 +702,36 @@ namespace Tensorflow
dtypes.resource, dtypes.variant}.Contains(dtype);
}

/// <summary>
/// Classifies what kind of tape gradient recording applies to
/// <paramref name="tensors"/>: none, first-order (exactly one non-persistent
/// tape watching), or higher-order (a persistent tape, or two watching tapes).
/// </summary>
public static int PossibleTapeGradientTypes(Tensor[] tensors)
{
    var tapes = tf.GetTapeSet();
    if (tapes is null || tapes.Count == 0)
    {
        // skip the forward_accumulators.
        return POSSIBLE_GRADIENT_TYPES_NONE;
    }

    var alreadyWatching = false;
    foreach (var tape in tapes)
    {
        if (!tape.ShouldRecord(tensors))
            continue;
        // A persistent tape — or a second watching tape — enables
        // higher-order gradients.
        if (tape.Persistent || alreadyWatching)
            return POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER;
        alreadyWatching = true;
    }

    // skip the forward_accumulators.
    return alreadyWatching ? POSSIBLE_GRADIENT_TYPES_FIRST_ORDER : POSSIBLE_GRADIENT_TYPES_NONE;
}

/// <summary>
/// Return true if op has real gradient.
/// </summary>
@@ -688,7 +752,7 @@ namespace Tensorflow

private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
{
scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
//scope = scope.TrimEnd('/').Replace('/', '_');
return grad_fn(op, out_grads);
}

@@ -701,5 +765,28 @@ namespace Tensorflow
throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
$"inputs {op.inputs._inputs.Count()}");
}

/// <summary>
/// Builds a SymbolicGradient op for a function-call op whose gradient function
/// is not otherwise available, forwarding the op's inputs plus the incoming
/// output gradients to the gradient of the called function.
/// </summary>
/// <param name="op">The function-call operation being differentiated.</param>
/// <param name="out_grads">Gradients flowing into the op's outputs.</param>
/// <returns>Gradients with respect to the op's inputs.</returns>
private static Tensor[] _SymGrad(Operation op, Tensor[] out_grads)
{
    var f_in = ((Tensor[])op.inputs).Concat(out_grads).ToArray();
    var f_types = ((Tensor[])op.inputs).Select(x => default_gradient.get_zeros_dtype(x)).ToArray();
    NameAttrList f = new();
    // Partitioned-call ops store the callee in their "f" attr; plain function
    // call ops use the op type itself as the function name.
    if (_IsPartitionedCall(op))
    {
        var func_attr = op.get_attr("f");
        Debug.Assert(func_attr is NameAttrList);
        f.Name = ((NameAttrList)func_attr).Name;
    }
    else
    {
        f.Name = op.type;
    }
    foreach (var k in op.node_def.Attr.Keys)
    {
        // Deep-copy each attr; protobuf Clone() replaces the previous
        // serialize/re-parse round trip with an equivalent direct copy.
        f.Attr[k] = op.node_def.Attr[k].Clone();
    }
    var in_grads = gen_functional_ops.symbolic_gradient(f_in, f_types, f);
    return in_grads;
}
}
}

+ 15
- 4
src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs View File

@@ -98,12 +98,23 @@ namespace Tensorflow
{
if (op.inputs == null) return null;

RegisterFromAssembly();
var gradient_function = op._gradient_function;
if(gradient_function is null)
{
RegisterFromAssembly();

if (!gradientFunctions.ContainsKey(op.type))
throw new LookupError($"can't get graident function through get_gradient_function {op.type}");

if (!gradientFunctions.ContainsKey(op.type))
throw new LookupError($"can't get graident function through get_gradient_function {op.type}");
return gradientFunctions[op.type];
}

return gradientFunctions[op.type];
Tensor[] wrapped_gradient_function(Operation operation, Tensor[] args)
{
return gradient_function(operation, args);
}
// TODO(Rinne): check if this needs to be registered.
return wrapped_gradient_function;
}
}
}

+ 345
- 10
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -1,6 +1,15 @@
using Google.Protobuf;
using System;
using System.Buffers;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Exceptions;
using Tensorflow.Framework;
using Tensorflow.Framework.Models;
using Tensorflow.Functions;
using Tensorflow.Operations;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs;
@@ -11,12 +20,65 @@ namespace Tensorflow.Graphs;
public class FuncGraph : Graph, IDisposable
{
internal SafeFuncGraphHandle _func_graph_handle;
internal HashSet<Tensor> _resource_tensor_inputs;
internal HashSet<WeakReference<IVariableV1>> _watched_variables;
internal IEnumerable<WeakReference<IVariableV1>> _weak_variables;
internal object[] _structured_outputs;
internal Dictionary<long, string> _output_names;
public string FuncName => _graph_key;

public Tensors Inputs { get; set; } = new Tensors();
public Tensors Outputs { get; set; } = new Tensors();
// Flattens `_structured_outputs` (a mix of single tensors and tensor
// sequences) into one flat Tensors list; throws TypeError for any other
// element type.
public Tensors FlatStructuredOutputs
{
    get
    {
        List<Tensor> res = new();
        foreach(var obj in _structured_outputs)
        {
            if(obj is Tensor tensor)
            {
                res.Add(tensor);
            }
            else if(obj is IEnumerable<Tensor> tensors)
            {
                res.AddRange(tensors);
            }
            else
            {
                throw new TypeError("The structured outputs member should be tensor or tensors.");
            }
        }
        return res;
    }
}
public string Name { get; set; }
public Dictionary<string, string> Attrs { get; set; }
// Variables referenced by this graph, held via weak references so the graph
// does not keep them alive. Reading the property dereferences each weak ref
// lazily (deferred LINQ) and throws if any target has been collected.
public IEnumerable<IVariableV1> Variables
{
    get
    {
        return _weak_variables.Select(v =>
        {
            if (v.TryGetTarget(out var target))
            {
                return target;
            }
            else
            {
                throw new AssertionError("Called a function referencing variables which have been deleted. " +
                    "This likely means that function-local variables were created and " +
                    "not referenced elsewhere in the program. This is generally a " +
                    "mistake; consider storing variables in an object attribute on first call.");
            }
        });
    }
    internal set
    {
        // Store only weak references; strong ownership stays with the caller.
        _weak_variables = value.Select(x => new WeakReference<IVariableV1>(x));
    }
}
public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable);
public Dictionary<string, AttrValue> Attrs { get; set; }

Dictionary<long, (Tensor, Tensor)> _captures
= new Dictionary<long, (Tensor, Tensor)>();
@@ -42,9 +104,12 @@ public class FuncGraph : Graph, IDisposable
outer_graph = outer_graph.OuterGraph;
_graph_key = Name = name;
building_function = true;
_weak_variables = new List<WeakReference<IVariableV1>>();
_resource_tensor_inputs = new HashSet<Tensor>();
_watched_variables = new HashSet<WeakReference<IVariableV1>>();
}

public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, string> attrs) : base()
public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, AttrValue> attrs) : base()
{
outer_graph = ops.get_default_graph();
while (outer_graph.building_function)
@@ -55,6 +120,9 @@ public class FuncGraph : Graph, IDisposable
// Will to test if FuncGraph has memory leak
// c_api.TF_DeleteGraph(_handle);
_handle = handle;
_weak_variables = new List<WeakReference<IVariableV1>>();
_resource_tensor_inputs = new HashSet<Tensor>();
_watched_variables = new HashSet<WeakReference<IVariableV1>>();
}

public void replace_capture(Tensor tensor, Tensor placeholder)
@@ -62,14 +130,14 @@ public class FuncGraph : Graph, IDisposable
_captures[tensor.Id] = (tensor, placeholder);
}

public void ToGraph(Operation[] opers,
public unsafe void ToGraph(Operation[] opers,
Tensor[] inputs, Tensor[] outputs,
string[] output_names)
{
var status = new Status();
if (output_names != null && output_names.Length == 0)
if (output_names is null)
{
output_names = null;
output_names = new string[0];
};

_func_graph_handle = c_api.TF_GraphToFunction(_handle,
@@ -81,7 +149,7 @@ public class FuncGraph : Graph, IDisposable
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
outputs.Length,
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
output_names,
output_names.Length != outputs.Length ? null : output_names,
IntPtr.Zero,
null,
status);
@@ -211,6 +279,19 @@ public class FuncGraph : Graph, IDisposable
Inputs.Add(placeholder);
}

// Removes the capture entry for `tensor` and returns its placeholder,
// or null when the tensor was never captured.
Tensor pop_capture(Tensor tensor)
{
    if (!_captures.TryGetValue(tensor.Id, out var capture))
    {
        return null;
    }
    _captures.Remove(tensor.Id);
    return capture.Item2;
}

Tensor _create_substitute_placeholder(Tensor value,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
@@ -234,10 +315,7 @@ public class FuncGraph : Graph, IDisposable

foreach (var (_name, attr_value) in enumerate(Attrs))
{
var serialized = new AttrValue
{
S = ByteString.CopyFromUtf8(attr_value)
}.ToByteArray();
var serialized = attr_value.ToByteArray();
c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status);
tf.Status.Check(true);
}
@@ -260,4 +338,261 @@ public class FuncGraph : Graph, IDisposable
{
c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status);
}

/// <summary>
/// Traces <paramref name="func"/> into a <see cref="FuncGraph"/> (the C# port of
/// python's `func_graph_from_py_func`): placeholders are built for the arguments,
/// the function is invoked under the graph, and the graph's inputs, outputs and
/// watched variables are recorded.
/// </summary>
/// <param name="name">Name for the new FuncGraph (used only when <paramref name="func_graph"/> is null).</param>
/// <param name="func">The function to trace.</param>
/// <param name="args">Positional arguments; placeholders are derived from them.</param>
/// <param name="kwargs">Keyword arguments; currently must be null or empty.</param>
/// <param name="signature">Optional specs that replace <paramref name="args"/> entirely.</param>
/// <param name="func_graph">Existing graph to trace into, or null to create a fresh one.</param>
/// <param name="autograph">Unsupported; must be false.</param>
/// <param name="autograph_options">Unused until autograph is supported.</param>
/// <param name="add_control_dependencies">Automatic control dependencies — currently a TODO.</param>
/// <param name="arg_names">Optional names for the argument placeholders.</param>
/// <param name="op_return_value">Tensor substituted for outputs that are bare Operations.</param>
/// <param name="capture_by_value">Currently unused.</param>
/// <param name="acd_record_initial_resource_uses">Currently unused.</param>
/// <returns>The populated <see cref="FuncGraph"/>.</returns>
/// <exception cref="NotImplementedException">For non-empty kwargs, autograph, or TensorArray outputs.</exception>
public static FuncGraph func_graph_from_func(string name, Func<object[], object[]> func,
    object[] args, Dictionary<string, object> kwargs, TensorSpec[] signature = null,
    FuncGraph func_graph = null, bool autograph = false, object autograph_options = null,
    bool add_control_dependencies = true, string[] arg_names = null,
    Tensor op_return_value = null, bool capture_by_value = false,
    bool acd_record_initial_resource_uses = false)
{
    // Create a fresh graph unless the caller supplied one to trace into.
    if(func_graph is null)
    {
        func_graph = new FuncGraph(name);
    }

    // TODO(Rinne): deal with control dependencies.

    // Trace under this graph; force resource variables on for the duration
    // of the trace (restored near the end of this method).
    func_graph.as_default();
    var current_scope = variable_scope.get_variable_scope();
    var default_use_resource = current_scope.use_resource;
    current_scope.use_resource = true;

    // An explicit signature overrides the concrete args: placeholders are
    // then built from the specs instead of from example values.
    if(signature is not null)
    {
        args = signature;
        kwargs = new Dictionary<string, object>();
    }
    var func_args = _get_defun_inputs_from_args(args, arg_names);
    var func_kwargs = _get_defun_inputs_from_kwargs(kwargs);

    if(func_kwargs is not null && func_kwargs.Count > 0)
    {
        throw new NotImplementedException("The keyword args has not been supported in `func_graph_from_func`.");
    }

    // Record resource-typed inputs (raw resource tensors and variable handles).
    foreach(var arg in nest.flatten<object>(new object[] { func_args, func_kwargs }))
    {
        if(arg is Tensor tensor && tensor.dtype == dtypes.resource)
        {
            func_graph._resource_tensor_inputs.Add(tensor);
        }
        else if (arg is ResourceVariable variable)
        {
            func_graph._resource_tensor_inputs.Add(variable.Handle);
        }
    }

    // skip the assignment of `func_graph.structured_input_signature`.

    // Provisional inputs: all flattened tensor arguments. This assignment is
    // replaced after tracing by de-duplicated inputs plus internal captures.
    var flat_func_args = nest.flatten(func_args as object);
    var flat_func_kwargs = nest.flatten(func_kwargs as object);
    func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs)
        .Where(x => x is Tensor).Select(x => (Tensor)x));

    //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true);
    //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true);

    // Normalizes one traced output: a bare Operation becomes `op_return_value`
    // (identity under a control dependency on the op), tensors are converted,
    // TensorArray is unsupported, and null passes through.
    Tensor convert(object x)
    {
        if (x is null) return null;
        Tensor res = null;
        if(op_return_value is not null && x is Operation)
        {
            tf_with(ops.control_dependencies(new object[] { x }), _ =>
            {
                res = array_ops.identity(op_return_value);
            });
        }
        else if(x is not TensorArray)
        {
            Debug.Assert(x is Tensor);
            res = ops.convert_to_tensor_or_composite(x as Tensor);
        }
        else
        {
            throw new NotImplementedException($"The `TensorArray` is not supported here currently.");
        }
        if (add_control_dependencies)
        {
            // TODO(Rinne): `x = deps_ctx.mark_as_return(x)`.
        }
        return res;
    }

    if (autograph)
    {
        throw new NotImplementedException("The autograph of `func_graph_from_func` has not been supported.");
    }

    // Invoke the user function under the graph so its ops are traced into it.
    var func_outputs = func(func_args);
    func_outputs = variable_utils.convert_variables_to_tensors(func_outputs);
    func_outputs = func_outputs.Select(x => convert(x)).ToArray();
    // TODO(Rinne): `check_func_mutation`.

    current_scope.use_resource = default_use_resource;

    // Separate true graph inputs from variable-argument captures: a variable
    // passed as an argument appears as a popped capture placeholder.
    var graph_variables = func_graph._watched_variables.ToList();
    HashSet<IVariableV1> arg_variables = new HashSet<IVariableV1>();
    List<Tensor> inputs = new();
    foreach(var arg in composite_tensor_utils.flatten_with_variables(func_args))
    {
        if(arg is BaseResourceVariable variable)
        {
            var resource_placeholder = func_graph.pop_capture(variable.Handle);
            if(resource_placeholder is null)
            {
                continue;
            }
            Debug.Assert(variable is IVariableV1);
            arg_variables.Add(variable as IVariableV1);
            inputs.Add(resource_placeholder);
        }
        else if(arg is Tensor tensor)
        {
            inputs.Add(tensor);
        }
    }
    // Watched variables are weak references; drop collected ones and those
    // already accounted for as explicit arguments.
    var variables = graph_variables.Select(v =>
    {
        if (v.TryGetTarget(out var target))
        {
            return target;
        }
        else
        {
            return null;
        }
    }).Where(v => v is not null && !arg_variables.Contains(v));
    // Final inputs: explicit argument tensors followed by internal captures.
    func_graph.Inputs = inputs.Concat(func_graph.internal_captures).ToArray();
    func_graph._structured_outputs = func_outputs;
    func_graph.Outputs.AddRange(func_graph.FlatStructuredOutputs.Where(x => x is not null)
        .Select(x => func_graph.capture(x)));

    func_graph.Variables = variables;

    func_graph.Exit();

    if (add_control_dependencies)
    {
        // TODO(Rinne): implement it.
    }
    return func_graph;
}

/// <summary>
/// Builds placeholder inputs for the positional arguments, using the
/// argument array itself as the structure template for re-nesting.
/// </summary>
private static object[] _get_defun_inputs_from_args(object[] args, string[] names)
{
    var structured = _get_defun_inputs(args, names, args);
    return structured as object[];
}

/// <summary>
/// Builds placeholder inputs for keyword arguments. Only the null/empty case
/// is currently supported; the full implementation would sort the kwargs by
/// key and route them through `_get_defun_inputs`.
/// </summary>
private static Dictionary<string, object> _get_defun_inputs_from_kwargs(Dictionary<string, object> kwargs)
{
    // TODO(Rinne): implement keyword-argument placeholder creation.
    Debug.Assert(kwargs is null || kwargs.Count == 0);
    return kwargs;
}

/// <summary>
/// Expands each argument into its leaf values, creates a defun input for
/// every leaf, and re-nests the results into the shape of
/// <paramref name="structured_args"/>.
/// </summary>
private static object _get_defun_inputs(object[] args, string[] names, object structured_args)
{
    // Without explicit names, pair each argument with a null name.
    if (names is null)
    {
        names = new string[args.Length];
    }

    var function_inputs = new List<object>();
    foreach (var (arg_value, name) in zip(args, names))
    {
        // One argument may flatten into several leaves (composites, variables, specs).
        foreach (var val in composite_tensor_utils.flatten_with_variables_or_variable_specs(arg_value))
        {
            function_inputs.Add(_get_defun_input(val, name));
        }
    }
    return nest.pack_sequence_as(structured_args, nest.flatten<object>(function_inputs), true);
}

/// <summary>
/// Creates the in-graph stand-in for one traced argument: a placeholder for a
/// Tensor or TensorSpec, a captured handle for a resource variable, and the
/// value itself for anything else.
/// </summary>
/// <param name="arg">The leaf argument value.</param>
/// <param name="name">Optional user-specified name for the placeholder (may be null).</param>
/// <returns>The placeholder tensor, or <paramref name="arg"/> itself for variables/unknown types.</returns>
private static object _get_defun_input(object arg, string name)
{
    var func_graph = ops.get_default_graph() as FuncGraph;
    Debug.Assert(func_graph is not null);
    if (arg is Tensor tensor)
    {
        Tensor placeholder;
        try
        {
            placeholder = tf.placeholder(tensor.dtype, tensor.shape, name);
        }
        catch (ValueError)
        {
            // TODO(Rinne): Add warning here.
            // The requested name was rejected; fall back to an auto-generated one.
            placeholder = tf.placeholder(tensor.dtype, tensor.shape);
        }
        handle_data_util.copy_handle_data(tensor, placeholder);
        if (name is not null)
        {
            placeholder.op._set_attr("_user_specified_name", new AttrValue()
            {
                S = tf.compat.as_bytes(name)
            });
        }
        return placeholder;
    }
    else if (arg is TensorSpec spec)
    {
        // Prefer the spec's own name; fall back to the positional name.
        string requested_name;
        if (!string.IsNullOrEmpty(spec.name))
        {
            requested_name = spec.name;
        }
        else
        {
            requested_name = name;
        }
        Tensor placeholder;
        try
        {
            placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name);
        }
        catch (ValueError)
        {
            // TODO(Rinne): Add warning here.
            placeholder = tf.placeholder(spec.dtype, spec.shape);
        }
        if (name is not null)
        {
            placeholder.op._set_attr("_user_specified_name", new AttrValue()
            {
                S = tf.compat.as_bytes(requested_name)
            });
        }
        return placeholder;
    }
    else if (arg is BaseResourceVariable variable)
    {
        // Capture the variable's resource handle; the variable object itself
        // is returned so downstream code still sees a variable.
        var placeholder = func_graph.capture(variable.Handle, name);
        // Bug fix: guard against a null name, consistent with the other
        // branches (previously `tf.compat.as_bytes(null)` could throw here).
        if (name is not null)
        {
            placeholder.op._set_attr("_user_specified_name", new AttrValue()
            {
                S = tf.compat.as_bytes(name)
            });
        }
        return arg;
    }
    // TODO(Rinne): deal with `VariableSpec`.
    else
    {
        return arg;
    }
}
}

+ 8
- 1
src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs View File

@@ -1,4 +1,6 @@
namespace Tensorflow
using Tensorflow.Graphs;

namespace Tensorflow
{
public partial class Graph
{
@@ -6,5 +8,10 @@
{

}

internal GraphOverrideGradientContext _override_gradient_function(Dictionary<string, Func<Operation, object[], Tensor[]>> gradient_function_map)
{
return new GraphOverrideGradientContext(this, gradient_function_map);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -118,7 +118,7 @@ namespace Tensorflow
/// <param name="compute_device">(Optional.) If True, device functions will be executed
/// to compute the device property of the Operation.</param>
/// <returns>An `Operation` object.</returns>
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true)
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true, OperationDescription desc = null)
{
var ret = new Operation(c_op, this);
_add_op(ret);


+ 29
- 2
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -21,6 +21,7 @@ using System.Collections.Specialized;
using System.Linq;
using Tensorflow.Framework;
using Tensorflow.Functions;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -88,6 +89,7 @@ namespace Tensorflow
private List<Operation> _unfetchable_ops = new List<Operation>();
private List<Tensor> _unfeedable_tensors = new List<Tensor>();
private Dictionary<string, EagerDefinedFunction> _functions = new();
internal Dictionary<string, Func<Operation, object[], Tensor[]>> _gradient_function_map = new();
private VersionDef _graph_def_versions = new VersionDef()
{
Producer = versions.GRAPH_DEF_VERSION,
@@ -161,13 +163,30 @@ namespace Tensorflow
return _functions.ContainsKey(tf.compat.as_str(name));
}

public void AddFunction(EagerDefinedFunction function)
internal void AddFunction(EagerDefinedFunction function)
{
_check_not_finalized();

var name = function.Name;
if(function._grad_func_name is not null && function.csharp_grad_func is not null)
{
throw new ValueError($"Gradient defined twice for function {name}");
}

// TODO(Rinne): deal with c_graph
var c_graph = this.c_graph;
var func = function._c_func.Get();
Status status = new();
if (function._grad_func is not null)
{
var gradient = function._grad_func._c_func.Get();
c_api.TF_GraphCopyFunction(c_graph, func, gradient, status);
status.Check(true);
}
else
{
c_api.TF_GraphCopyFunction(c_graph, func, new SafeFuncGraphHandle(IntPtr.Zero), status);
status.Check(true);
}

_functions[tf.compat.as_str(name)] = function;

@@ -332,6 +351,9 @@ namespace Tensorflow

private void _create_op_helper(Operation op, bool compute_device = true)
{
// high priority
// TODO(Rinne): complete the implementation.
op._gradient_function = _gradient_function_map.GetOrDefault(op.type, null);
_record_op_seen_by_control_dependencies(op);
}

@@ -548,6 +570,11 @@ namespace Tensorflow
ops.pop_graph();
}

internal EagerDefinedFunction _get_function(string name)
{
return _functions.GetOrDefault(name, null);
}

string debugString = string.Empty;
public override string ToString()
{


+ 37
- 0
src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs View File

@@ -0,0 +1,37 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;

namespace Tensorflow.Graphs
{
/// <summary>
/// Scoped override of a graph's custom gradient-function map: on enter the
/// graph's (expected-empty) map is swapped for the supplied one, on exit it
/// is reset to a fresh empty map.
/// </summary>
internal class GraphOverrideGradientContext: ITensorFlowObject
{
    readonly Graph _graph;
    readonly Dictionary<string, Func<Operation, object[], Tensor[]>> _new_gradient_function_map;

    public GraphOverrideGradientContext(Graph graph,
        Dictionary<string, Func<Operation, object[], Tensor[]>> new_gradient_function_map)
    {
        _graph = graph;
        _new_gradient_function_map = new_gradient_function_map;
    }

    [DebuggerStepThrough]
    public void __enter__()
    {
        // The graph must not already have gradient overrides installed.
        Debug.Assert(_graph._gradient_function_map.Count == 0);
        _graph._gradient_function_map = _new_gradient_function_map;
    }

    [DebuggerStepThrough]
    public void __exit__()
    {
        _graph._gradient_function_map = new Dictionary<string, Func<Operation, object[], Tensor[]>>();
    }

    public void Dispose()
    {
        // Nothing to release; the map reset happens in __exit__.
    }
}
}

+ 94
- 13
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -20,6 +20,9 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow.Util;
using static Tensorflow.Binding;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using System.Diagnostics;

namespace Tensorflow
{
@@ -47,6 +50,8 @@ namespace Tensorflow

private readonly Graph _graph;

internal Func<Operation, object[], Tensor[]> _gradient_function;

public string type => OpType;

public Graph graph => _graph;
@@ -61,7 +66,7 @@ namespace Tensorflow

public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle));

// OperationDescription _opDesc;
//private OperationDescription _op_desc;

public NodeDef node_def => GetNodeDef();

@@ -216,21 +221,19 @@ namespace Tensorflow

var x = AttrValue.Parser.ParseFrom(buf.ToArray());

string oneof_value = x.ValueCase.ToString();
if (string.IsNullOrEmpty(oneof_value))
return null;
var oneof_value = x.ValueCase;
if (oneof_value == AttrValue.ValueOneofCase.None)
return new object[0];

switch (oneof_value.ToLower())
if(oneof_value == AttrValue.ValueOneofCase.List)
{
case "list":
throw new NotImplementedException($"Unsupported field type in {oneof_value}");
case "type":
return x.Type;
case "s":
return x.S.ToStringUtf8();
default:
return x.GetType().GetProperty(oneof_value).GetValue(x);
throw new NotImplementedException($"Unsupported field type in {oneof_value}");
}
if(oneof_value == AttrValue.ValueOneofCase.Type)
{
return dtypes.as_tf_dtype(x.Type);
}
return ProtoUtils.GetSingleAttrValue(x, oneof_value);
}

public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
@@ -309,5 +312,83 @@ namespace Tensorflow
}

public NDArray numpy() => throw new NotImplementedException("");

/// <summary>
/// Appends new output tensors to this operation. Each appended tensor gets
/// the next free value index (`orig_num_outputs + i`) and the corresponding
/// shape from <paramref name="shapes"/>.
/// </summary>
/// <param name="types">Dtypes of the outputs to append.</param>
/// <param name="shapes">Shapes of the outputs to append; same length as <paramref name="types"/>.</param>
internal void _add_outputs(TF_DataType[] types, Shape[] shapes)
{
    Debug.Assert(types.Length == shapes.Length);
    int orig_num_outputs = this.outputs.Length;

    // Since the `_outputs` is defined as `Array`, when we add new output, we
    // have to create a new array, which brings some performance concerns.
    // In the future maybe the type of `outputs` should be reconsidered.
    var old_outputs = _outputs;
    _outputs = new Tensor[orig_num_outputs + types.Length];
    for (int i = 0; i < orig_num_outputs; i++)
    {
        _outputs[i] = old_outputs[i];
    }

    for (int i = 0; i < types.Length; i++)
    {
        // Bug fix: the value index must advance per output (was the constant
        // `orig_num_outputs + 1`), and the new tensor must be stored AFTER the
        // original outputs (was `_outputs[i]`, clobbering the existing ones).
        var t = new Tensor(this, orig_num_outputs + i, types[i]);
        t.shape = shapes[i];
        _outputs[orig_num_outputs + i] = t;
    }
}

/// <summary>
/// Stores a function-valued attribute on this op by wrapping the function
/// name in a `NameAttrList` proto.
/// </summary>
internal void _set_func_attr(string attr_name, string func_name)
{
    var attr_value = new AttrValue() { Func = new NameAttrList() { Name = func_name } };
    _set_attr(attr_name, attr_value);
}

/// <summary>
/// Stores a list-of-dtypes attribute on this op. A null or empty list is a
/// no-op (nothing is recorded).
/// </summary>
internal void _set_type_list_attr(string attr_name, DataType[] types)
{
    if (types is null || types.Length == 0)
    {
        return;
    }

    var list_value = new AttrValue.Types.ListValue();
    list_value.Type.AddRange(types);
    _set_attr(attr_name, new AttrValue() { List = list_value });
}

/// <summary>
/// Serializes <paramref name="attr_value"/> into a TF buffer and forwards it
/// to <see cref="_set_attr_with_buf"/>; the native buffer is always released.
/// </summary>
internal void _set_attr(string attr_name, AttrValue attr_value)
{
    var buf = new Buffer(attr_value.ToByteArray());
    try
    {
        _set_attr_with_buf(attr_name, buf);
    }
    finally
    {
        // Release the native buffer even if setting the attr throws.
        buf.Release();
    }
}

/// <summary>
/// Intended to write a serialized `AttrValue` buffer onto this operation via
/// the C API. Currently a no-op stub: the `OperationDescription` field it
/// would need is commented out above, so there is nothing to attach the attr
/// to. NOTE(review): this means `_set_attr` currently has no runtime effect —
/// confirm callers tolerate that until the TODO below is resolved.
/// </summary>
internal void _set_attr_with_buf(string attr_name, Buffer attr_buf)
{
    // Earlier implementation attempts, kept for reference:
    //if(_op_desc is null)
    //{
    //    //var new_node_def = NodeDef.Parser.ParseFrom(node_def.ToByteArray());
    //    //new_node_def.Name += "_temp";
    //    //var op = new Operation(new_node_def, graph, inputs, _output_types, control_inputs, _input_types);
    //    //Status status = new();
    //    //c_api.TF_SetAttrBool(op._op_desc, "trainable", true);
    //    ////c_api.TF_SetAttrValueProto(op._op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status);
    //    //status.Check(true);
    //    // TODO(Rinne): deal with it. Give a warning or make the Operation always contains `op_desc`.
    //}
    //else
    //{
    //    //Status status = new();
    //    //c_api.TF_SetAttrValueProto(_op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status);
    //    //status.Check(true);
    //}
}
}
}

+ 4
- 4
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -208,9 +208,9 @@ namespace Tensorflow

[DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data);
//[DllImport(TensorFlowLibName)]
//public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
//[DllImport(TensorFlowLibName)]
//public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data);
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/functional_ops.cs View File

@@ -39,7 +39,7 @@ namespace Tensorflow

if (config is null)
{
config = function_utils.get_disabled_rewriter_config();
config = function_utils.get_disabled_rewriter_config().ToString();
}

if (executor_type is null)


+ 45
- 0
src/TensorFlowNET.Core/Operations/gen_functional_ops.cs View File

@@ -79,5 +79,50 @@ namespace Tensorflow.Operations

};
}

/// <summary>
/// Emits a `SymbolicGradient` op computing the gradient function of `f`.
/// In eager mode it tries the fast-path runner first, then the eager
/// fallback; if both throw, it falls through to graph-mode op construction.
/// </summary>
/// <param name="input">Inputs to the gradient op (original inputs plus output gradients).</param>
/// <param name="Tout">Dtypes of the gradient outputs.</param>
/// <param name="f">The function whose gradient is being computed.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The gradient tensors.</returns>
public static Tensor[] symbolic_gradient(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name = null)
{
    var ctx = tf.Context;
    if (ctx.executing_eagerly())
    {
        try
        {
            var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(
                "SymbolicGradient", name, input, Tout, f));
            return _result;
        }
        catch (Exception)
        {
            // Deliberately swallowed: fall through to the eager fallback below.
        }

        try
        {
            return symbolic_gradient_eager_fallback(input, Tout, f, name, ctx);
        }
        catch (Exception)
        {
            // Deliberately swallowed: fall through to graph-mode construction.
        }
    }
    var op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, new object[] { input, Tout, f });
    var result = op.outputs;
    if (execute.must_record_gradient())
    {
        // Higher-order gradient recording is not implemented yet.
        throw new NotImplementedException();
    }
    return result;
}

/// <summary>
/// Eager fallback for `SymbolicGradient`: executes the op directly through
/// `execute.executes` with the attrs laid out as alternating key/value pairs.
/// </summary>
public static Tensor[] symbolic_gradient_eager_fallback(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name, Context ctx)
{
    var attrs = new object[] { "Tin", input, "Tout", Tout, "f", f };
    var outputs = execute.executes("SymbolicGradient", Tout.Length, input, attrs, ctx, name);
    if (execute.must_record_gradient())
    {
        // Higher-order gradient recording is not implemented yet.
        throw new NotImplementedException();
    }
    return outputs;
}
}
}

+ 38
- 0
src/TensorFlowNET.Core/Operations/gen_ops.cs View File

@@ -10050,13 +10050,51 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor ensure_shape(Tensor input, Shape shape, string name = "EnsureShape")
{
    // In eager mode: try the fast path, then the eager fallback, then fall
    // through to graph-mode op construction.
    var ctx = tf.Context;
    if (ctx.executing_eagerly())
    {
        try
        {
            var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("EnsureShape", name, input, shape));
            return _result[0];
        }
        catch (Exception)
        {
            // Bug fix: removed a leftover `Console.WriteLine()` that printed a
            // blank line to stdout on every fast-path failure. The exception is
            // still deliberately swallowed so we can try the fallback below.
        }
        try
        {
            return ensure_shape_eager_fallback(input, shape, name, ctx);
        }
        catch (Exception)
        {
            // Swallowed: fall through to graph-mode construction.
        }
    }

    var dict = new Dictionary<string, object>();
    dict["input"] = input;
    dict["shape"] = shape;
    var op = tf.OpDefLib._apply_op_helper("EnsureShape", name: name, keywords: dict);
    if (execute.must_record_gradient())
    {
        // Gradient recording for this op is not implemented yet.
        throw new NotImplementedException();
    }
    return op.output;
}

/// <summary>
/// Eager fallback for `EnsureShape`: runs the op via `execute.executes`,
/// passing the target shape and element dtype as alternating attr key/value
/// pairs.
/// </summary>
public static Tensor ensure_shape_eager_fallback(Tensor input, Shape shape, string name, Context ctx)
{
    var attrs = new object[] { "shape", shape, "T", input.dtype.as_datatype_enum() };
    var results = execute.executes("EnsureShape", 1, new Tensor[] { input },
        attrs, ctx, name);
    if (execute.must_record_gradient())
    {
        // Gradient recording for this op is not implemented yet.
        throw new NotImplementedException();
    }
    return results[0];
}

/// <summary>
/// Creates or finds a child frame, and makes <c>data</c> available to the child frame.
/// </summary>


+ 2
- 0
src/TensorFlowNET.Core/Operations/handle_data_util.cs View File

@@ -52,5 +52,7 @@ namespace Tensorflow.Operations
// TODO(Rinne): enable it. (currently the internal c api cannot be invoked.)
//c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray());
}

public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op);
}
}

+ 46
- 25
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -24,6 +24,7 @@ using static Tensorflow.CppShapeInferenceResult.Types;
using static Tensorflow.Binding;
using Tensorflow.Operations;
using System.Buffers;
using Tensorflow.Eager;

namespace Tensorflow
{
@@ -41,12 +42,7 @@ namespace Tensorflow
name: name);
}

public static bool is_resource_variable(IVariableV1 var)
{
return var is BaseResourceVariable;
}
public static bool is_resource_variable(Trackable var)
public static bool is_resource_variable(object var)
{
return var is BaseResourceVariable;
}
@@ -138,10 +134,27 @@ namespace Tensorflow
/// <param name="graph_mode"></param>
internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode)
{
tensor.HandleData = handle_data;
if (!graph_mode)
return;

var size = handle_data.ShapeAndType.Count;

var shapes = new IntPtr[size];
var types = new DataType[size];
var ranks = new int[size];

for (int i = 0; i < size; i++)
{
var shapeAndType = handle_data.ShapeAndType[i];
types[i] = shapeAndType.Dtype;
ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count;
var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray();
}

//tensor.HandleData = handle_data;
//if (!graph_mode)
// return;

//var shapes = handle_data.ShapeAndType.Select(x => x.Shape);
//var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray();
//var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray();
@@ -196,24 +209,6 @@ namespace Tensorflow
throw new NotImplementedException("");
}

private static HandleData get_eager_safe_handle_data(Tensor handle)
{
if (handle.Handle == null)
{
var data = new HandleData();
data.ShapeAndType.Add(new HandleShapeAndType
{
Shape = handle.shape.as_shape_proto(),
Dtype = handle.dtype.as_datatype_enum()
});
return data;
}
else
{
return HandleData.Parser.ParseFrom(handle.BufferToArray());
}
}

/// <summary>
/// Copies an existing variable to a new graph, with no initializer.
/// </summary>
@@ -281,5 +276,31 @@ namespace Tensorflow
}
}
}

public static HandleData get_eager_safe_handle_data(Tensor handle)
{
if (handle.Handle == null)
{
var data = new HandleData();
data.ShapeAndType.Add(new HandleShapeAndType
{
Shape = handle.shape.as_shape_proto(),
Dtype = handle.dtype.as_datatype_enum()
});
return data;
}
else
{
return HandleData.Parser.ParseFrom(handle.BufferToArray());
}
//if(handle is EagerTensor)
//{
// return handle.HandleData;
//}
//else
//{
// return handle_data_util.get_resource_handle_data(handle);
//}
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -101,6 +101,7 @@ namespace Tensorflow
_op = op;
_value_index = value_index;
_override_dtype = dtype;
_tf_output = null;
_id = ops.uid();
}



+ 6
- 4
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -136,9 +136,9 @@ namespace Tensorflow
protected virtual void SetShapeInternal(Shape value)
{
if (value == null)
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status);
c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), null, -1, tf.Status);
else
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status);
c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), value.dims, value.ndim, tf.Status);
}

public int[] _shape_tuple()
@@ -177,7 +177,9 @@ namespace Tensorflow
if (_handle == null)
{
var output = _as_tf_output();
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status);
Status status = new();
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status);
status.Check(true);
return ndim;
}

@@ -199,7 +201,7 @@ namespace Tensorflow
public TF_Output _as_tf_output()
{
if (!_tf_output.HasValue)
_tf_output = new TF_Output(op, value_index);
_tf_output = new TF_Output(op, _value_index);

return _tf_output.Value;
}


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -56,7 +56,7 @@ namespace Tensorflow
public void Add(Tensor tensor)
=> items.Add(tensor);

public void AddRange(Tensor[] tensors)
public void AddRange(IEnumerable<Tensor> tensors)
=> items.AddRange(tensors);

public void Insert(int index, Tensor tensor)


+ 6
- 5
src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs View File

@@ -12,11 +12,12 @@ namespace Tensorflow.Training.Saving.SavedModel
{
public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph)
{
this.forward_backward = concrete_function.forward_backward;
this.Outputs = concrete_function.Outputs;
this.ReturnType = concrete_function.ReturnType;
this.OutputStructure = concrete_function.OutputStructure;
this.ArgKeywords = concrete_function.ArgKeywords;
throw new NotImplementedException();
//this.forward_backward = concrete_function.forward_backward;
//this.Outputs = concrete_function.Outputs;
//this.ReturnType = concrete_function.ReturnType;
//this.OutputStructure = concrete_function.OutputStructure;
//this.ArgKeywords = concrete_function.ArgKeywords;
}
}
}

+ 68
- 27
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -30,6 +30,31 @@ namespace Tensorflow.Training.Saving.SavedModel
{
var function_spec = _deserialize_function_spec_as_nonmethod(saved_function.FunctionSpec);

Tensor[] restored_function_body(Tensor[] inputs)
{
if(saved_function.ConcreteFunctions is null || saved_function.ConcreteFunctions.Count == 0)
{
throw new ValueError("Found zero restored functions for caller function.");
}
foreach(var function_name in saved_function.ConcreteFunctions)
{
var function = concrete_functions[function_name];
if(function.CapturedInputs.Any(x => x is null))
{
throw new ValueError("Looks like you are trying to run a loaded " +
"non-Keras model that was trained using tf.distribute.experimental.ParameterServerStrategy " +
"with variable partitioning, which is not currently supported. Try using Keras to define your model " +
"if possible.");
}
if(_concrete_function_callable_with(function, inputs, false))
{
return _call_concrete_function(function, inputs);
}
}
throw new ValueError("Unexpected runtime behavior, please submit an issue to " +
"https://github.com/SciSharp/TensorFlow.NET/issues");
}

List<ConcreteFunction> concrete_function_objects = new();
foreach(var concrete_function_name in saved_function.ConcreteFunctions)
{
@@ -40,17 +65,10 @@ namespace Tensorflow.Training.Saving.SavedModel
cf._set_function_spec(function_spec);
}

foreach(var function_name in saved_function.ConcreteFunctions)
{
var function = concrete_functions[function_name];
if(_concrete_function_callable_with(function, null, false))
{
return new RestoredFunction(null, function, "function_from_deserialization");
}
}
return new RestoredFunction(x => x, new ConcreteFunction(x => x, TF_DataType.TF_FLOAT), "function_return_itself");
//throw new ValueError("Unexpected runtime behavior, please submit an issue to " +
// "https://github.com/SciSharp/TensorFlow.NET/issues");
var restored_function = new RestoredFunction(restored_function_body, nameof(restored_function_body),
function_spec, concrete_function_objects);

return restored_function;
}

public static Dictionary<string, ConcreteFunction> load_function_def_library(FunctionDefLibrary library,
@@ -102,15 +120,17 @@ namespace Tensorflow.Training.Saving.SavedModel
{
var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types);

if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name))
object structured_input_signature = null;
object structured_outputs = null;
if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name))
{
// TODO(Rinne): implement it.
//var proto = saved_object_graph.ConcreteFunctions[orig_name];
//throw new NotImplementedException();
var proto = saved_object_graph.ConcreteFunctions[orig_name];
structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
}

graph.as_default();
var func_graph = function_def_lib.function_def_to_graph(fdef, null, null);
var func_graph = function_def_lib.function_def_to_graph(fdef, structured_input_signature, structured_outputs);
graph.Exit();

_restore_gradient_functions(func_graph, renamed_functions, loaded_gradients);
@@ -124,7 +144,7 @@ namespace Tensorflow.Training.Saving.SavedModel
{
fdef.Attr.Remove("_input_shapes");
}
var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value.S.ToString()));
var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value));
if(wrapper_function is not null)
{
throw new NotImplementedException();
@@ -142,8 +162,7 @@ namespace Tensorflow.Training.Saving.SavedModel
{
var gradient_op_type = gradients_to_register[orig_name];
loaded_gradients[gradient_op_type] = func;
// TODO(Rinne): deal with gradient registry.
//new RegisteredGradient() { RegisteredOpType = gradient_op_type }.
ops.RegisterGradientFunction(gradient_op_type, _gen_gradient_func(func));
}
}
return functions;
@@ -203,6 +222,16 @@ namespace Tensorflow.Training.Saving.SavedModel
}
}

private static Func<Operation, Tensor[], Tensor[]> _gen_gradient_func(ConcreteFunction func)
{
return (unused_op, result_grads) =>
{
result_grads = zip(result_grads, func.func_graph.Inputs)
.Select((item) => item.Item1 is null ? default_gradient.zeros_like(item.Item2) : item.Item1).ToArray();
return func.CallFlat(result_grads, func.CapturedInputs);
};
}

private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients)
{
foreach(var op in func_graph.get_operations())
@@ -210,14 +239,14 @@ namespace Tensorflow.Training.Saving.SavedModel
if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
{
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
// TODO(Rinne): deal with `op._gradient_function`.
op.op._gradient_function = function._get_gradient_function();
}
string gradient_op_type = null;
try
{
gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
}
catch(Exception e)
catch(InvalidArgumentError)
{
continue;
}
@@ -389,7 +418,7 @@ namespace Tensorflow.Training.Saving.SavedModel
concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList();
concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments;

var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec);
//var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec);
// TODO(Rinne): set the function spec.
concrete_function.AddTograph();
return concrete_function;
@@ -413,19 +442,31 @@ namespace Tensorflow.Training.Saving.SavedModel
return function.CallFlat(inputs, function.CapturedInputs);
}

private static bool _concrete_function_callable_with(ConcreteFunction function, Tensors inputs, bool allow_conversion)
private static bool _concrete_function_callable_with(ConcreteFunction function, Tensor[] inputs, bool allow_conversion)
{
// TODO(Rinne): revise it.
return true;
return function.CapturedInputs.Length + inputs.Length == function.Inputs.Length;
//var expected_inputs = function.func_graph.Inputs;
//foreach(var (arg, expected) in zip(inputs, expected_inputs))
//{
// if(arg.Id != expected.Id)
// {
// return false;
// }
//}
//return true;
}
}

public class RestoredFunction : Function
{
public RestoredFunction(Func<Tensors, Tensors> function, ConcreteFunction concrete_function,
string name, bool auto_graph = true): base(function, name, auto_graph)
IEnumerable<ConcreteFunction> _concrete_functions;
FunctionSpec _function_spec;
public RestoredFunction(Func<Tensor[], Tensor[]> function, string name, FunctionSpec function_spec,
IEnumerable<ConcreteFunction> concrete_functions): base(function, name, auto_graph: false)
{
_concrete_variable_creation_fn = concrete_function;
_concrete_functions = concrete_functions;
_function_spec = function_spec;
}

protected override bool _run_functions_eagerly()


+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs View File

@@ -102,6 +102,6 @@ public class SignatureMap: Trackable
return new Dictionary<string, Trackable>();
}

return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value);
return _signatures.Where(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value);
}
}

+ 3
- 3
src/TensorFlowNET.Core/Training/data_structures.cs View File

@@ -132,8 +132,8 @@ namespace Tensorflow.Training
{
get
{
var trainable_extra_variables = _self_extra_variables.TakeWhile(x => x.Trainable).ToList();
var non_trainable_extra_variables = _self_extra_variables.TakeWhile(x => !x.Trainable).ToList();
var trainable_extra_variables = _self_extra_variables.Where(x => x.Trainable).ToList();
var non_trainable_extra_variables = _self_extra_variables.Where(x => !x.Trainable).ToList();
List<IVariableV1> non_trainable_variables = new();
foreach(var obj in Values)
{
@@ -576,7 +576,7 @@ namespace Tensorflow.Training

if(save_type == SaveType.SAVEDMODEL)
{
children = children.Concat(this.TakeWhile(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair<string, Trackable>(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value);
children = children.Concat(this.Where(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair<string, Trackable>(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value);
}

return children;


+ 24
- 0
src/TensorFlowNET.Core/Util/ProtoUtils.cs View File

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Util
{
internal static class ProtoUtils
{
    /// <summary>
    /// Extracts the single value stored in an <see cref="AttrValue"/> oneof.
    /// </summary>
    /// <param name="value">The attr value proto to read.</param>
    /// <param name="valueCase">Which oneof field is set.</param>
    /// <returns>The boxed value of the selected field.</returns>
    /// <exception cref="NotImplementedException">
    /// Thrown for cases this helper does not handle (e.g. List, Placeholder,
    /// None). Previously these fell through to an opaque
    /// SwitchExpressionException because the switch had no discard arm.
    /// </exception>
    public static object GetSingleAttrValue(AttrValue value, AttrValue.ValueOneofCase valueCase)
    {
        return valueCase switch
        {
            AttrValue.ValueOneofCase.S => value.S,
            AttrValue.ValueOneofCase.I => value.I,
            AttrValue.ValueOneofCase.F => value.F,
            AttrValue.ValueOneofCase.B => value.B,
            AttrValue.ValueOneofCase.Type => value.Type,
            AttrValue.ValueOneofCase.Shape => value.Shape,
            AttrValue.ValueOneofCase.Tensor => value.Tensor,
            AttrValue.ValueOneofCase.Func => value.Func,
            _ => throw new NotImplementedException($"Unsupported attr value case: {valueCase}"),
        };
    }
}
}

+ 3
- 3
src/TensorFlowNET.Core/Util/function_utils.cs View File

@@ -7,15 +7,15 @@ namespace Tensorflow.Util
{
internal static class function_utils
{
private static string _rewriter_config_optimizer_disabled;
public static string get_disabled_rewriter_config()
private static ByteString _rewriter_config_optimizer_disabled;
public static ByteString get_disabled_rewriter_config()
{
if(_rewriter_config_optimizer_disabled is null)
{
var config = new ConfigProto();
var rewriter_config = config.GraphOptions.RewriteOptions;
rewriter_config.DisableMetaOptimizer = true;
_rewriter_config_optimizer_disabled = config.ToString();
_rewriter_config_optimizer_disabled = config.ToByteString();
}
return _rewriter_config_optimizer_disabled;
}


+ 24
- 4
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -137,10 +137,12 @@ namespace Tensorflow.Util
switch (instance)
{
case Hashtable hash:
var result = new Hashtable();
foreach ((object key, object value) in zip<object, object>(_sorted(hash), args))
result[key] = value;
return result;
{
var result = new Hashtable();
foreach ((object key, object value) in zip<object, object>(_sorted(hash), args))
result[key] = value;
return result;
}
}
}
//else if( _is_namedtuple(instance) || _is_attrs(instance))
@@ -221,6 +223,16 @@ namespace Tensorflow.Util
return list;
}

public static List<T> flatten<T>(IEnumerable<T> structure)
{
var list = new List<T>();
foreach(var item in structure)
{
_flatten_recursive(item, list);
}
return list;
}

public static object[] flatten2(ICanBeFlattened structure)
=> structure.Flatten();

@@ -527,6 +539,14 @@ namespace Tensorflow.Util
return pack_sequence_as(structure, mapped_flat_structure) as T2;
}

public static IEnumerable<T2> map_structure<T1, T2>(Func<T1, T2> func, IEnumerable<T1> structure) where T2 : class
{
var flat_structure = flatten(structure);
var mapped_flat_structure = flat_structure.Select(func).Select(x => (object)x);

return pack_sequence_as(structure, mapped_flat_structure) as IEnumerable<T2>;
}

/// <summary>
/// Same as map_structure, but with only one structure (no combining of multiple structures)
/// </summary>


+ 33
- 0
src/TensorFlowNET.Core/Util/variable_utils.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework;

namespace Tensorflow.Util
{
internal static class variable_utils
{
public static Tensor[] convert_variables_to_tensors(object[] values)
{
return values.Select(x =>
{
if (resource_variable_ops.is_resource_variable(x))
{
return ops.convert_to_tensor(x);
}
else if (x is CompositeTensor)
{
throw new NotImplementedException("The composite tensor has not been fully supported.");
}
else if(x is Tensor tensor)
{
return tensor;
}
else
{
throw new TypeError("Currently the output of function to be traced must be `Tensor`.");
}
}).ToArray();
}
}
}

+ 5
- 3
src/TensorFlowNET.Core/ops.cs View File

@@ -248,7 +248,7 @@ namespace Tensorflow
foreach (var attr in node_def.Attr)
{
var bytes = attr.Value.ToByteArray();
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status);
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: (ulong)bytes.Length, status: status);
status.Check(true);
}

@@ -575,10 +575,12 @@ namespace Tensorflow

public static HandleData get_resource_handle_data(Tensor graph_op)
{
throw new NotImplementedException();
// This implementation hasn't been checked for some reasons.
// If it throws an exception in the future, please check it.
var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));

//var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
//return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
}

public static void dismantle_graph(Graph graph)


+ 4
- 0
src/TensorFlowNET.Keras/Engine/Model.Train.cs View File

@@ -35,6 +35,10 @@ namespace Tensorflow.Keras.Engine
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
using var tape = tf.GradientTape();
//foreach (var variable in TrainableVariables)
//{
// tape.watch(variable.Handle);
//}
var y_pred = Apply(x, training: true);
var loss = compiled_loss.Call(y, y_pred);



+ 2
- 2
src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs View File

@@ -84,8 +84,8 @@ namespace Tensorflow.Keras.Layers
inputs.Insert(index, value);
}

var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]);
var op = graph._create_op_from_tf_operation(c_op);
var (c_op, op_desc) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]);
var op = graph._create_op_from_tf_operation(c_op, desc: op_desc);
op._control_flow_post_processing();

// Record the gradient because custom-made ops don't go through the


+ 3
- 3
src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs View File

@@ -51,9 +51,9 @@ namespace Tensorflow.Keras.Saving.SavedModel
_all_functions = new HashSet<string>(objects_and_functions.Item2);
}

public IDictionary<string, Trackable> Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);
public IDictionary<string, Trackable> Functions => _function_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);

public IDictionary<string, Trackable> CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);
public IDictionary<string, Trackable> CheckpointableObjects => _object_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);

/// <summary>
/// Returns functions to attach to the root object during serialization.
@@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
{
get
{
var objects = CheckpointableObjects.TakeWhile( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value);
var objects = CheckpointableObjects.Where( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value);
objects[Constants.KERAS_ATTR] = _keras_trackable;
return objects;
}


BIN
test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/fingerprint.pb View File


+ 6
- 0
test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb View File

@@ -0,0 +1,6 @@

¯root"_tf_keras_sequential*Š{"name": "sequential", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 784]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}, {"class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}}]}, "shared_object_id": 3, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 784]}, "ndim": 2, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 784]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 784]}, "float32", "input_1"]}, "keras_version": "2.11.0", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 784]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "shared_object_id": 0}, {"class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}, "shared_object_id": 1}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": 
true, "dtype": "float32", "axis": -1}, "shared_object_id": 2}]}}, "training_config": {"loss": "sparse_categorical_crossentropy", "metrics": [[{"class_name": "MeanMetricWrapper", "config": {"name": "accuracy", "dtype": "float32", "fn": "categorical_accuracy"}, "shared_object_id": 5}]], "weighted_metrics": null, "loss_weights": null, "optimizer_config": {"class_name": "Custom>Adam", "config": {"name": "Adam", "weight_decay": null, "clipnorm": null, "global_clipnorm": null, "clipvalue": null, "use_ema": false, "ema_momentum": 0.99, "ema_overwrite_frequency": null, "jit_compile": false, "is_legacy_optimizer": false, "learning_rate": 0.0010000000474974513, "beta_1": 0.9, "beta_2": 0.999, "epsilon": 1e-07, "amsgrad": false}}}}2
ÿroot.layer_with_weights-0"_tf_keras_layer*È{"name": "transformer", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}, "shared_object_id": 1, "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}}2
Þ root.layer-1"_tf_keras_layer*´{"name": "softmax", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "shared_object_id": 2, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10]}}2
¸9root.keras_api.metrics.0"_tf_keras_metric*�{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 6}2
ë:root.keras_api.metrics.1"_tf_keras_metric*´{"class_name": "MeanMetricWrapper", "name": "accuracy", "dtype": "float32", "config": {"name": "accuracy", "dtype": "float32", "fn": "categorical_accuracy"}, "shared_object_id": 5}2

BIN
test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/saved_model.pb View File


BIN
test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.data-00000-of-00001 View File


BIN
test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.index View File


+ 17
- 3
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -6,7 +6,6 @@ using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;

@@ -62,11 +61,26 @@ public class SequentialModelLoad
[TestMethod]
public void Temp()
{
var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func");
var model = tf.keras.models.load_model(@"Assets/python_func_model");
model.summary();

var x = tf.ones((2, 10));
var x = tf.random.uniform((8, 784), -1, 1);
var y = model.Apply(x);
Console.WriteLine(y);

//model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

//var data_loader = new MnistModelLoader();
//var num_epochs = 1;
//var batch_size = 8;

//var dataset = data_loader.LoadAsync(new ModelLoadSetting
//{
// TrainDir = "mnist",
// OneHot = false,
// ValidationSize = 58000,
//}).Result;

//model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
}
}

+ 16
- 0
test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj View File

@@ -49,6 +49,22 @@
<None Update="Assets\simple_model_from_auto_compile\bias0.npy">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>

<None Update="Assets\python_func_model\fingerprint.pb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\python_func_model\keras_metadata.pb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\python_func_model\saved_model.pb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\python_func_model\variables\variables.data-00000-of-00001">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\python_func_model\variables\variables.index">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project>

+ 1
- 1
test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs View File

@@ -413,7 +413,7 @@ namespace Tensorflow.Native.UnitTest
ASSERT_EQ(TF_OK, s_.Code, s_.Message);
ASSERT_NE(func_, IntPtr.Zero);
ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_)));
c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_);
c_api.TF_GraphCopyFunction(host_graph_, func_, new SafeFuncGraphHandle(IntPtr.Zero), s_);
ASSERT_EQ(TF_OK, s_.Code, s_.Message);
}



Loading…
Cancel
Save