Browse Source

Partially support the analysis of loaded functions.

tags/v0.100.5-BERT-load
AsakusaRinne 2 years ago
parent
commit
acae9b3e39
42 changed files with 782 additions and 284 deletions
  1. +16
    -2
      TensorFlow.NET.sln
  2. +13
    -0
      Tensorflow.Common/Extensions/OneofExtension.cs
  3. +11
    -0
      Tensorflow.Common/Tensorflow.Common.csproj
  4. +11
    -9
      src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
  5. +4
    -3
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  6. +5
    -4
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  7. +21
    -115
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  8. +17
    -16
      src/TensorFlowNET.Core/Checkpoint/restore.cs
  9. +13
    -0
      src/TensorFlowNET.Core/Eager/forwardprop_util.cs
  10. +66
    -4
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  11. +50
    -3
      src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
  12. +35
    -5
      src/TensorFlowNET.Core/Functions/Function.cs
  13. +9
    -7
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  14. +1
    -0
      src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs
  15. +20
    -7
      src/TensorFlowNET.Core/Functions/monomorphic_function.cs
  16. +5
    -0
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  17. +1
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  18. +4
    -0
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  19. +70
    -0
      src/TensorFlowNET.Core/Operations/functional_ops.cs
  20. +83
    -0
      src/TensorFlowNET.Core/Operations/gen_functional_ops.cs
  21. +20
    -5
      src/TensorFlowNET.Core/Operations/handle_data_util.cs
  22. +17
    -1
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  23. +1
    -1
      src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs
  24. +16
    -9
      src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs
  25. +5
    -0
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  26. +1
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  27. +4
    -3
      src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs
  28. +50
    -3
      src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  29. +22
    -14
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  30. +18
    -18
      src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
  31. +9
    -8
      src/TensorFlowNET.Core/Training/Trackable.cs
  32. +23
    -0
      src/TensorFlowNET.Core/Util/function_utils.cs
  33. +4
    -3
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  34. +9
    -1
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  35. +5
    -4
      src/TensorFlowNET.Core/Variables/UninitializedVariable.cs
  36. +6
    -2
      src/TensorFlowNET.Core/ops.cs
  37. +2
    -2
      src/TensorFlowNET.Keras/Engine/Layer.cs
  38. +58
    -21
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  39. +9
    -4
      src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs
  40. +15
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs
  41. +27
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
  42. +6
    -10
      src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs

+ 16
- 2
TensorFlow.NET.sln View File

@@ -1,7 +1,7 @@
 
Microsoft Visual Studio Solution File, Format Version 12.00 Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 16
VisualStudioVersion = 16.0.31624.102
# Visual Studio Version 17
VisualStudioVersion = 17.4.33213.308
MinimumVisualStudioVersion = 10.0.40219.1 MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject EndProject
@@ -23,6 +23,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest",
EndProject EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}"
EndProject EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Common", "Tensorflow.Common\Tensorflow.Common.csproj", "{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}"
EndProject
Global Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU Debug|Any CPU = Debug|Any CPU
@@ -153,6 +155,18 @@ Global
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.Build.0 = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.ActiveCfg = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.Build.0 = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.ActiveCfg = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.Build.0 = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.ActiveCfg = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.Build.0 = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.ActiveCfg = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.Build.0 = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.ActiveCfg = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection EndGlobalSection
GlobalSection(SolutionProperties) = preSolution GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE HideSolutionNode = FALSE


+ 13
- 0
Tensorflow.Common/Extensions/OneofExtension.cs View File

@@ -0,0 +1,13 @@
using OneOf;
using System;

namespace Tensorflow.Common.Extensions
{
public static class OneofExtension
{
public static bool IsTypeOrDeriveFrom<T>(this IOneOf src)
{
return src.Value is T;
}
}
}

+ 11
- 0
Tensorflow.Common/Tensorflow.Common.csproj View File

@@ -0,0 +1,11 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="OneOf" Version="3.0.223" />
</ItemGroup>

</Project>

+ 11
- 9
src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs View File

@@ -1,10 +1,12 @@
using System;
using OneOf;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Training; using Tensorflow.Training;
using Tensorflow.Common.Extensions;
using pbc = global::Google.Protobuf.Collections; using pbc = global::Google.Protobuf.Collections;


namespace Tensorflow.Checkpoint namespace Tensorflow.Checkpoint
@@ -28,7 +30,7 @@ namespace Tensorflow.Checkpoint
); );
public static class SaveUtil public static class SaveUtil
{ {
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
public static (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
{ {
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
@@ -117,16 +119,16 @@ namespace Tensorflow.Checkpoint
/// <param name="call_with_mapped_captures"></param> /// <param name="call_with_mapped_captures"></param>
/// <param name="cache"></param> /// <param name="cache"></param>
/// <param name="object_graph_proto"></param> /// <param name="object_graph_proto"></param>
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
private static IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
{ {
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
foreach(var td in tensor_trackables) foreach(var td in tensor_trackables)
{ {
// TODO: deal with cache. // TODO: deal with cache.
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
Trackable trackable = null; Trackable trackable = null;
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict;
IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict;
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
{ {
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto);
@@ -148,12 +150,12 @@ namespace Tensorflow.Checkpoint
return serialized_tensors; return serialized_tensors;
} }


private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
private static IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{ {
var trackable = trackable_data.object_to_save; var trackable = trackable_data.object_to_save;


// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
if (call_with_mapped_captures) if (call_with_mapped_captures)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
@@ -164,7 +166,7 @@ namespace Tensorflow.Checkpoint
} }


// TODO: deal with the type `SaveSpce` (currently it will never be it). // TODO: deal with the type `SaveSpce` (currently it will never be it).
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
foreach(var pair in ret_tensor_dict) foreach(var pair in ret_tensor_dict)
{ {
var local_name = TrackableUtils.escape_local_name(pair.Key); var local_name = TrackableUtils.escape_local_name(pair.Key);
@@ -200,7 +202,7 @@ namespace Tensorflow.Checkpoint
/// <param name="call_with_mapped_captures"></param> /// <param name="call_with_mapped_captures"></param>
/// <param name="object_graph_proto"></param> /// <param name="object_graph_proto"></param>
/// <returns></returns> /// <returns></returns>
private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
private static (Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{ {
Dictionary<Trackable, string> object_names = new(); Dictionary<Trackable, string> object_names = new();


+ 4
- 3
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -8,6 +8,7 @@ using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections; using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Google.Protobuf; using Google.Protobuf;
using OneOf;


namespace Tensorflow.Checkpoint; namespace Tensorflow.Checkpoint;


@@ -179,13 +180,13 @@ public static class SaveUtilV1


// TODO: tensorflow python has a process with callable `saveable_factory`. // TODO: tensorflow python has a process with callable `saveable_factory`.
List<MySaveableObject> saveables = new(); List<MySaveableObject> saveables = new();
if (maybe_saveable.TryGet<MySaveableObject>(out var s))
if (maybe_saveable.TryPickT1(out var s, out var variable))
{ {
saveables.Add(s); saveables.Add(s);
} }
else else
{ {
saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key));
saveables.AddRange(saveable_object_util.saveable_objects_for_op(variable as Trackable, key));
} }


foreach (var saveable in saveables) foreach (var saveable in saveables)
@@ -217,7 +218,7 @@ public static class SaveUtilV1


public record class CheckpointFactoryData public record class CheckpointFactoryData
( (
Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
Func<string, OneOf<BaseResourceVariable, MySaveableObject>> factory,
string name, string name,
string checkpoint_key string checkpoint_key
); );

+ 5
- 4
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -12,6 +12,7 @@ using static Tensorflow.Binding;
using Tensorflow.Operations; using Tensorflow.Operations;
using Newtonsoft.Json; using Newtonsoft.Json;
using Tensorflow.Training; using Tensorflow.Training;
using OneOf;


namespace Tensorflow.Checkpoint; namespace Tensorflow.Checkpoint;


@@ -49,7 +50,7 @@ public class TrackableSaver
} }


private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
private (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
gather_serialized_tensors(Tensor? object_graph_tensor = null) gather_serialized_tensors(Tensor? object_graph_tensor = null)
{ {
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache);
@@ -68,7 +69,7 @@ public class TrackableSaver
Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
if (!serialized_tensors.ContainsKey(Trackable.None)) if (!serialized_tensors.ContainsKey(Trackable.None))
{ {
serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>();
serialized_tensors[Trackable.None] = new Dictionary<string, OneOf.OneOf<Tensor, IDictionary<string, Tensor>>>();
} }
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
return (serialized_tensors, feed_additions, registered_savers, graph_proto); return (serialized_tensors, feed_additions, registered_savers, graph_proto);
@@ -400,7 +401,7 @@ public class CheckpointRestoreCoordinator
// skip the callback. // skip the callback.
} }


public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null)
public List<Operation> restore_saveables(Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null)
{ {
List<Operation> restore_ops = new(); List<Operation> restore_ops = new();
foreach(var position in positions) foreach(var position in positions)
@@ -412,7 +413,7 @@ public class CheckpointRestoreCoordinator
Dictionary<string, BaseResourceVariable> variable_dict = new(); Dictionary<string, BaseResourceVariable> variable_dict = new();
foreach(var item in tensor_saveables) foreach(var item in tensor_saveables)
{ {
if(item.Value.TryGet<BaseResourceVariable>(out var variable))
if(item.Value.TryPickT0(out var variable, out var _))
{ {
variable_dict[item.Key] = variable; variable_dict[item.Key] = variable;
} }


+ 21
- 115
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -15,106 +15,14 @@ using Tensorflow.Graphs;
using System.Xml.Linq; using System.Xml.Linq;
using System.Diagnostics; using System.Diagnostics;
using RestoreFunc = System.Func<object, object>; using RestoreFunc = System.Func<object, object>;
using OneOf;


namespace Tensorflow.Checkpoint namespace Tensorflow.Checkpoint
{ {
public class Maybe<TA, TB>
{
private TA? _valueA = default(TA);
private TB? _valueB = default(TB);
private Type _type;
private bool _assignedTA;
public Maybe(TA value)
{
_valueA = value;
_type= typeof(TA);
_assignedTA = true;
}
public Maybe(TB value)
{
_valueB = value;
_type = typeof(TB);
_assignedTA = false;
}

public Type DataType => _type;

/// <summary>
/// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned.
/// It returns
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="res"></param>
/// <returns></returns>
public bool TryGet<T>(out T? res)
{
if(_valueA is T && _valueB is not T)
{
res = (T)(object)_valueA;
return _assignedTA;
}
else if(_valueA is not T && _valueB is T)
{
res = (T)(object)_valueB;
return !_assignedTA;
}
res = default(T);
return false;
}

public bool IsTypeOrDeriveFrom<T>()
{
if (_valueA is T && _valueB is not T)
{
return _assignedTA;
}
else if (_valueA is not T && _valueB is T)
{
return !_assignedTA;
}
else if (_valueA is T && _valueB is T)
{
return true;
}
else
{
return false;
}
}

public T GetValue<T>()
{
if (_valueA is T && _valueB is not T)
{
return (T)(object)_valueA;
}
else if (_valueA is not T && _valueB is T)
{
return (T)(object)_valueB;
}
else if (_valueA is T && _valueB is T)
{
throw new TypeError("The type is vague, this is always because TA and TB both derive from T.");
}
else
{
throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}.");
}
}

public static implicit operator Maybe<TA, TB>(TA a)
{
return new Maybe<TA, TB>(a);
}
public static implicit operator Maybe<TA, TB>(TB b)
{
return new Maybe<TA, TB>(b);
}
}
internal class SingleDeviceSaver internal class SingleDeviceSaver
{ {
private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict;
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict)
private IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> _tensor_slice_dict;
public SingleDeviceSaver(IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_slice_dict)
{ {
_tensor_slice_dict = tensor_slice_dict; _tensor_slice_dict = tensor_slice_dict;
} }
@@ -122,15 +30,15 @@ namespace Tensorflow.Checkpoint
{ {
_tensor_slice_dict = tensor_slice_dict.ToDictionary( _tensor_slice_dict = tensor_slice_dict.ToDictionary(
x => x.Key, x => x.Value.ToDictionary( x => x.Key, x => x.Value.ToDictionary(
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
as IDictionary<string, Maybe<Tensor, SaveSpec>>);
y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT0(y.Value))
as IDictionary<string, OneOf<Tensor, SaveSpec>>);
} }
public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict) public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict)
{ {
_tensor_slice_dict = tensor_slice_dict.ToDictionary( _tensor_slice_dict = tensor_slice_dict.ToDictionary(
x => x.Key, x => x.Value.ToDictionary( x => x.Key, x => x.Value.ToDictionary(
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
as IDictionary<string, Maybe<Tensor, SaveSpec>>);
y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT1(y.Value))
as IDictionary<string, OneOf<Tensor, SaveSpec>>);
} }
public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) public Operation? save(Tensor file_prefix, CheckpointOptions? options = null)
{ {
@@ -149,7 +57,7 @@ namespace Tensorflow.Checkpoint
{ {
var slice_spec = slice.Key; var slice_spec = slice.Key;
var maybe_tensor = slice.Value; var maybe_tensor = slice.Value;
if(maybe_tensor.TryGet<SaveSpec>(out var spec))
if(maybe_tensor.TryPickT1(out var spec, out var tensor))
{ {
var tensor_value = spec.tensor; var tensor_value = spec.tensor;
if (tensor_value is not null) if (tensor_value is not null)
@@ -161,7 +69,6 @@ namespace Tensorflow.Checkpoint
} }
else else
{ {
var tensor = maybe_tensor.GetValue<Tensor>();
tensor_names.Add(checkpoint_key); tensor_names.Add(checkpoint_key);
tensors.Add(tensor); tensors.Add(tensor);
slice_specs.Add(slice_spec); slice_specs.Add(slice_spec);
@@ -193,7 +100,7 @@ namespace Tensorflow.Checkpoint
var slice_spec = slice.Key; var slice_spec = slice.Key;
var maybe_tensor = slice.Value; var maybe_tensor = slice.Value;
// TODO: deal with other types. Currently only `SaveSpec` is allowed. // TODO: deal with other types. Currently only `SaveSpec` is allowed.
if(maybe_tensor.TryGet<SaveSpec>(out var spec))
if(maybe_tensor.TryPickT1(out var spec, out var tensor))
{ {
tensor_dtypes.Add(spec.dtype); tensor_dtypes.Add(spec.dtype);
slice_specs.Add(spec.slice_spec); slice_specs.Add(spec.slice_spec);
@@ -201,7 +108,6 @@ namespace Tensorflow.Checkpoint
} }
else else
{ {
var tensor = maybe_tensor.GetValue<Tensor>();
tensor_dtypes.Add(tensor.dtype); tensor_dtypes.Add(tensor.dtype);
slice_specs.Add(slice_spec); slice_specs.Add(slice_spec);
tensor_names.Add(checkpoint_key); tensor_names.Add(checkpoint_key);
@@ -254,7 +160,7 @@ namespace Tensorflow.Checkpoint
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> /// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param>
/// <param name="registered_savers"></param> /// <param name="registered_savers"></param>
/// <param name="call_with_mapped_capture"></param> /// <param name="call_with_mapped_capture"></param>
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors,
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors,
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false)
{ {
_keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>();
@@ -274,9 +180,9 @@ namespace Tensorflow.Checkpoint
{ {
restore_fn = new RestoreFunc(x => restore_fn = new RestoreFunc(x =>
{ {
if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>)
if(x is IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>)
{ {
return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>);
return obj._restore_from_tensors(x as IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>);
} }
throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}."); throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}.");
}); });
@@ -286,14 +192,14 @@ namespace Tensorflow.Checkpoint
{ {
var checkpoint_key = item.Key; var checkpoint_key = item.Key;
IDictionary<string, Tensor> spec_to_tensor; IDictionary<string, Tensor> spec_to_tensor;
if(item.Value.TryGet<Tensor>(out var t))
if(item.Value.TryPickT0(out var t, out var dic))
{ {
spec_to_tensor = new Dictionary<string, Tensor>(); spec_to_tensor = new Dictionary<string, Tensor>();
spec_to_tensor[""] = t; spec_to_tensor[""] = t;
} }
else else
{ {
spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>();
spec_to_tensor = dic;
} }


foreach(var spec in spec_to_tensor) foreach(var spec in spec_to_tensor)
@@ -399,7 +305,7 @@ namespace Tensorflow.Checkpoint


IDictionary<string, Operation> restore_func() IDictionary<string, Operation> restore_func()
{ {
Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new();
Dictionary<RestoreFunc, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new();
Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count);
Dictionary<string, Operation> restore_ops = new(); Dictionary<string, Operation> restore_ops = new();


@@ -419,29 +325,29 @@ namespace Tensorflow.Checkpoint
var slice_spec = item.Key; var slice_spec = item.Key;
var tensor = item.Value; var tensor = item.Value;
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>());
if (!string.IsNullOrEmpty(slice_spec)) if (!string.IsNullOrEmpty(slice_spec))
{ {
if (!internal_dict.ContainsKey(checkpoint_key)) if (!internal_dict.ContainsKey(checkpoint_key))
{ {
Dictionary<string, Tensor> dict = new(); Dictionary<string, Tensor> dict = new();
dict[slice_spec] = tensor; dict[slice_spec] = tensor;
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict);
} }
else else
{ {
internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
internal_dict[checkpoint_key].AsT1[slice_spec] = tensor;
} }
} }
else else
{ {
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor);
} }
restore_fn_input_count[restore_fn]--; restore_fn_input_count[restore_fn]--;


if (restore_fn_input_count[restore_fn] == 0) if (restore_fn_input_count[restore_fn] == 0)
{ {
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
foreach(var input in restore_fn_inputs[restore_fn]) foreach(var input in restore_fn_inputs[restore_fn])
{ {
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
@@ -519,7 +425,7 @@ namespace Tensorflow.Checkpoint


public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)
{ {
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
foreach (var saveable in saveables) foreach (var saveable in saveables)
{ {
var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable });


+ 17
- 16
src/TensorFlowNET.Core/Checkpoint/restore.cs View File

@@ -1,4 +1,5 @@
using System;
using OneOf;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
@@ -61,13 +62,13 @@ public class CheckpointPosition
} }
} }


public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables()
public (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables()
{ {
// skip the registered_saver // skip the registered_saver


if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0)
{ {
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(),
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(),
new List<CheckpointPosition>(), null); new List<CheckpointPosition>(), null);
} }


@@ -75,7 +76,7 @@ public class CheckpointPosition


List<Operation> existing_restore_ops; List<Operation> existing_restore_ops;
List<CheckpointPosition> positions = new(); List<CheckpointPosition> positions = new();
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables;
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables;
if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
{ {
(existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories);
@@ -109,8 +110,8 @@ public class CheckpointPosition
/// Creates a saveable using the _serialize_to_tensor method. /// Creates a saveable using the _serialize_to_tensor method.
/// </summary> /// </summary>
/// <param name="saveable_factories"></param> /// <param name="saveable_factories"></param>
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable(
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories)
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable(
IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories)
{ {
string suffix = SaveableCompat.get_saveable_name(this.Trackable); string suffix = SaveableCompat.get_saveable_name(this.Trackable);
suffix = suffix ?? ""; suffix = suffix ?? "";
@@ -124,23 +125,23 @@ public class CheckpointPosition


var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name);
// skip the cache. // skip the cache.
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new();
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> dict = new();
dict[saveable_name] = saveable; dict[saveable_name] = saveable;
return (new List<Operation>(), dict); return (new List<Operation>(), dict);
} }


private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name(
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories)
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name(
IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories)
{ {
// TODO(Rinne): implement it. // TODO(Rinne): implement it.
if(ObjectProto.Attributes is null) if(ObjectProto.Attributes is null)
{ {
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>());
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>());
} }


List<Operation> existing_restore_ops = new(); List<Operation> existing_restore_ops = new();
HashSet<string> created_compat_names = new(); HashSet<string> created_compat_names = new();
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new();
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables = new();
foreach (var serialized_tensor in ObjectProto.Attributes) foreach (var serialized_tensor in ObjectProto.Attributes)
{ {
Operation existing_op; Operation existing_op;
@@ -172,12 +173,12 @@ public class CheckpointPosition
_checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name);
continue; continue;
} }
named_saveables[serialized_tensor.CheckpointKey] = saveable;
named_saveables[serialized_tensor.CheckpointKey] = saveable.Value;
} }
return (existing_restore_ops, named_saveables); return (existing_restore_ops, named_saveables);
} }


private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories,
private OneOf<BaseResourceVariable, MySaveableObject>? _get_saveable_from_factory(IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories,
TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names)
{ {
var expected_factory_name = serialized_tensor.Name; var expected_factory_name = serialized_tensor.Name;
@@ -221,7 +222,7 @@ public class CheckpointPosition
Queue<(CheckpointPosition, Trackable)> visit_queue = new(); Queue<(CheckpointPosition, Trackable)> visit_queue = new();
visit_queue.Enqueue((this, this.Trackable)); visit_queue.Enqueue((this, this.Trackable));
List<Operation> restore_ops = new(); List<Operation> restore_ops = new();
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new();
Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables = new();
List<CheckpointPosition> positions = new(); List<CheckpointPosition> positions = new();


CheckpointPosition current_position = null; CheckpointPosition current_position = null;
@@ -306,7 +307,7 @@ public class CheckpointPosition
} }
} }


private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore()
private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore()
{ {
var trackable = this.Trackable; var trackable = this.Trackable;
trackable._maybe_initialize_trackable(); trackable._maybe_initialize_trackable();
@@ -318,7 +319,7 @@ public class CheckpointPosition
} }
else else
{ {
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(),
return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(),
new List<CheckpointPosition>(), null); new List<CheckpointPosition>(), null);
} }
} }


+ 13
- 0
src/TensorFlowNET.Core/Eager/forwardprop_util.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Eager
{
public class TangentInfo
{
// TODO(Rinne): implement it.
public object Indices { get; set; }
public object Tangents { get; set; }
}
}

+ 66
- 4
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -1,6 +1,8 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.Linq; using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Framework.Models; using Tensorflow.Framework.Models;
using Tensorflow.Graphs; using Tensorflow.Graphs;
using Tensorflow.Train; using Tensorflow.Train;
@@ -17,11 +19,13 @@ namespace Tensorflow.Functions
internal FuncGraph func_graph; internal FuncGraph func_graph;
protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; protected DelayedRewriteGradientFunctions _delayed_rewrite_functions;
protected Dictionary<string, string> _attrs; protected Dictionary<string, string> _attrs;
protected FunctionSpec _function_spec;
protected FunctionSpec _pre_initialized_function_spec = null;
internal ForwardBackwardCall forward_backward; internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs; public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures; public Tensor[] CapturedInputs => func_graph.external_captures;


public string Name => _delayed_rewrite_functions.forward().Name;
public string Name => _delayed_rewrite_functions.Forward().Name;


public Tensor[] Outputs; public Tensor[] Outputs;
public Type ReturnType; public Type ReturnType;
@@ -175,7 +179,13 @@ namespace Tensorflow.Functions
var (forward_function, args_with_tangents) = forward_backward.Forward(); var (forward_function, args_with_tangents) = forward_backward.Forward();
Tensors flat_outputs = null; Tensors flat_outputs = null;
if (executing_eagerly) if (executing_eagerly)
{
flat_outputs = forward_function.Call(args_with_tangents);
}
else
{
flat_outputs = forward_function.Call(args_with_tangents); flat_outputs = forward_function.Call(args_with_tangents);
}
forward_backward.Record(flat_outputs); forward_backward.Record(flat_outputs);
return flat_outputs; return flat_outputs;
} }
@@ -186,7 +196,7 @@ namespace Tensorflow.Functions
{ {
g = ops.get_default_graph(); g = ops.get_default_graph();
} }
_delayed_rewrite_functions.forward().AddToGraph(g);
_delayed_rewrite_functions.Forward().AddToGraph(g);
} }


public void SetExternalCaptures(IEnumerable<Tensor> captures) public void SetExternalCaptures(IEnumerable<Tensor> captures)
@@ -196,8 +206,60 @@ namespace Tensorflow.Functions


ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
{ {
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
return new ForwardBackwardCall(functions, args, tape_watching: true);
TangentInfo input_tangents;
if (executing_eagerly)
{
throw new NotImplementedException();
}
else
{
input_tangents = new TangentInfo();
}
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
{
if(input_tangents.Indices is not null || executing_eagerly)
{
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
return new ForwardBackwardCall(functions, args, tape_watching: true);
}
else
{
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: true);
}
}
else if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER)
{
throw new NotImplementedException();
}

// TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall.
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false);
}

internal void _set_function_spec(FunctionSpec spec)
{
_function_spec = null;
_pre_initialized_function_spec = spec;
_initialize_function_spec();
}

internal void _initialize_function_spec()
{
if(_pre_initialized_function_spec is null)
{
return;
}
Debug.Assert(_function_spec is null, "already initialized");
var spec = _pre_initialized_function_spec;
//var args = spec.Fullargspec.DictValue.Fields["args"];
// TODO(Rinne): self.structured_input_signature

_function_spec = new FunctionSpec()
{
Fullargspec = spec.Fullargspec,
IsMethod = spec.IsMethod,
InputSignature = spec.InputSignature
};
} }


public override string ToString() public override string ToString()


+ 50
- 3
src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs View File

@@ -5,6 +5,8 @@ using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Contexts; using Tensorflow.Contexts;
using Tensorflow.Graphs; using Tensorflow.Graphs;
using Tensorflow.Operations;
using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Functions namespace Tensorflow.Functions
@@ -14,7 +16,10 @@ namespace Tensorflow.Functions
public int _num_outputs; public int _num_outputs;
FuncGraph _func_graph; FuncGraph _func_graph;
FunctionDef _definition; FunctionDef _definition;
Tensor[] _func_graph_outputs;
public string Name => _func_graph.FuncName; public string Name => _func_graph.FuncName;
public DataType[] OutputTypes { get; protected set; }
public Shape[] OutputShapes { get; protected set; }
public FunctionDef Definition public FunctionDef Definition
{ {
get get
@@ -36,27 +41,69 @@ namespace Tensorflow.Functions
var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op))
.Select(x => x as Operation).ToArray(); .Select(x => x as Operation).ToArray();
var output_names = new string[0]; var output_names = new string[0];
OutputShapes = outputs.Select(x => x.shape).ToArray();
OutputTypes = outputs.Select(x => x.dtype.as_datatype_enum()).ToArray();


_func_graph = new FuncGraph(graph, name, attrs); _func_graph = new FuncGraph(graph, name, attrs);
_func_graph_outputs = new List<Tensor>(outputs).ToArray();
_func_graph.ToGraph(operations, inputs, outputs, output_names); _func_graph.ToGraph(operations, inputs, outputs, output_names);
} }


public Tensors Call(Tensors args) public Tensors Call(Tensors args)
{ {
// TODO(Rinne): Add arg `CancellationManager`.
// TODO(Rinne): Check the arg length.
var function_call_options = tf.Context.FunctionCallOptions;
string config;
if (string.IsNullOrEmpty(function_call_options.config_proto_serialized()))
{
config = function_utils.get_disabled_rewriter_config();
}
else
{
config = function_call_options.config_proto_serialized();
}
// TODO(Rinne): executor_type
var executing_eagerly = tf.Context.executing_eagerly();

var attrs = new object[] var attrs = new object[]
{ {
"executor_type", "", "executor_type", "",
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
}; };


var results = tf.Runner.TFE_Execute(tf.Context,
Tensor[] outputs;
if (executing_eagerly)
{
outputs = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName, tf.Context.DeviceName,
_func_graph.FuncName, _func_graph.FuncName,
args, args,
attrs, attrs,
_num_outputs); _num_outputs);

return results;
}
else
{
tf.GradientTape().stop_recording();
outputs = functional_ops.partitioned_call(args, this, OutputTypes,
executing_eagerly, config, "");
}
foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs))
{
handle_data_util.copy_handle_data(func_graph_output, outputs[i]);
}
if (executing_eagerly)
{
return outputs;
}
else
{
foreach(var (i, shape) in enumerate(OutputShapes))
{
outputs[i].shape = shape;
}
return outputs;
}
} }


public void AddToGraph(Graph g = null) public void AddToGraph(Graph g = null)


+ 35
- 5
src/TensorFlowNET.Core/Functions/Function.cs View File

@@ -9,16 +9,46 @@ namespace Tensorflow
#pragma warning disable CS0169 // The field 'Function._handle' is never used #pragma warning disable CS0169 // The field 'Function._handle' is never used
private IntPtr _handle; private IntPtr _handle;
#pragma warning restore CS0169 // The field 'Function._handle' is never used #pragma warning restore CS0169 // The field 'Function._handle' is never used

protected Func<Tensors, Tensors> _function;
protected ConcreteFunction _concrete_variable_creation_fn;
protected bool _auto_graph;
public string Name { get; set; } public string Name { get; set; }
public Function()
public Function(Func<Tensors, Tensors> function,
string name, bool auto_graph = true)
{
_function = function;
Name = name;
_auto_graph = auto_graph;
}

public virtual Tensors Apply(Tensors inputs)
{ {
if (_run_functions_eagerly())
{
return _function(inputs);
}


var result = _call(inputs);
return result;
} }
public Function(string name)
protected virtual Tensors _call(Tensors inputs)
{ {
Name = name;
_initialize();

return _concrete_variable_creation_fn.CallFlat(inputs,
_concrete_variable_creation_fn.CapturedInputs);
}

protected virtual bool _run_functions_eagerly()
{
return false;
}

private void _initialize()
{

} }
} }
} }

+ 9
- 7
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -15,11 +15,11 @@ namespace Tensorflow.Functions
/// </summary> /// </summary>
public abstract class TapeGradientFunctions public abstract class TapeGradientFunctions
{ {
string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";
string _FORWARD_PREFIX = "__forward_";
string _BACKWARD_PREFIX = "__backward_";
string _INFERENCE_PREFIX = "__inference_";
protected string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
protected string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";
protected string _FORWARD_PREFIX = "__forward_";
protected string _BACKWARD_PREFIX = "__backward_";
protected string _INFERENCE_PREFIX = "__inference_";


protected FuncGraph _func_graph; protected FuncGraph _func_graph;
protected EagerDefinedFunction _forward; protected EagerDefinedFunction _forward;
@@ -35,8 +35,9 @@ namespace Tensorflow.Functions
_func_graph = func_graph; _func_graph = func_graph;
} }


public EagerDefinedFunction Forward(Tensors inference_args)
public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null)
{ {
// TODO(Rinne): add input_tangents arg.
return ForwardAndBackwardFunctions(inference_args); return ForwardAndBackwardFunctions(inference_args);
} }


@@ -45,8 +46,9 @@ namespace Tensorflow.Functions
/// </summary> /// </summary>
/// <param name="flat_outputs"></param> /// <param name="flat_outputs"></param>
/// <param name="inference_args"></param> /// <param name="inference_args"></param>
public void Record(Tensors flat_outputs, Tensors inference_args)
public virtual void Record(Tensors flat_outputs, Tensors inference_args)
{ {
// TODO(Rinne): add arg `input_tagents`.
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record,
getBackwardFunction: backward_function); getBackwardFunction: backward_function);


+ 1
- 0
src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Operations; using Tensorflow.Operations;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Variables;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Functions namespace Tensorflow.Functions


+ 20
- 7
src/TensorFlowNET.Core/Functions/monomorphic_function.cs View File

@@ -5,16 +5,13 @@ using Tensorflow.Graphs;


namespace Tensorflow.Functions namespace Tensorflow.Functions
{ {
public class DelayedRewriteGradientFunctions
public class DelayedRewriteGradientFunctions: TapeGradientFunctions
{ {
static readonly string _INFERENCE_PREFIX = "__inference_";
static readonly string _BACKWARD_PREFIX = "__backward_";
static readonly string _FORWARD_PREFIX = "__forward_";
FuncGraph _func_graph;
EagerDefinedFunction _inference_function; EagerDefinedFunction _inference_function;
Dictionary<string, string> _attrs; Dictionary<string, string> _attrs;
int _num_inference_outputs; int _num_inference_outputs;
public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, string> attrs) public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, string> attrs)
:base(func_graph, false)
{ {
_func_graph= func_graph; _func_graph= func_graph;
_inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name), _inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name),
@@ -23,7 +20,7 @@ namespace Tensorflow.Functions
_num_inference_outputs = _func_graph.Outputs.Length; _num_inference_outputs = _func_graph.Outputs.Length;
} }


public EagerDefinedFunction forward(Tensors inference_args = null, Tensors input_tangents = null)
public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null)
{ {
if(input_tangents is not null) if(input_tangents is not null)
{ {
@@ -33,7 +30,23 @@ namespace Tensorflow.Functions
return _inference_function; return _inference_function;
} }


private static string _inference_name(string name)
public override void Record(Tensors flat_outputs, Tensors inference_args)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
base.Record(flat_outputs, inference_args);
}

//private (BackwardFunction, Tensors) _backward(Tensors outputs)
//{
// Tensor[] backward_function(Tensor[] grads, long[] unneeded_gradients)
// {
// var call_op = outputs[0].op;

// }
//}

private string _inference_name(string name)
{ {
return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}";
} }


+ 5
- 0
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -25,6 +25,11 @@ namespace Tensorflow
{ {
public class gradients_util public class gradients_util
{ {
// Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
// unfortunately too slow to use here.
public static int POSSIBLE_GRADIENT_TYPES_NONE = 0;
public static int POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1;
public static int POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2;
public static Tensor[] _GradientsHelper(Tensor[] ys, public static Tensor[] _GradientsHelper(Tensor[] ys,
Tensor[] xs, Tensor[] xs,
Tensor[] grad_ys = null, Tensor[] grad_ys = null,


+ 1
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -129,6 +129,7 @@ namespace Tensorflow
protected Graph outer_graph; protected Graph outer_graph;
public Graph OuterGraph => outer_graph; public Graph OuterGraph => outer_graph;
public Dictionary<string, EagerDefinedFunction> Functions => _functions; public Dictionary<string, EagerDefinedFunction> Functions => _functions;
public SafeGraphHandle c_graph => _handle;


public Graph() public Graph()
{ {


+ 4
- 0
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -208,5 +208,9 @@ namespace Tensorflow


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data);
} }
} }

+ 70
- 0
src/TensorFlowNET.Core/Operations/functional_ops.cs View File

@@ -14,10 +14,14 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Framework; using Tensorflow.Framework;
using Tensorflow.Functions;
using Tensorflow.Operations;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -25,6 +29,72 @@ namespace Tensorflow
{ {
public class functional_ops public class functional_ops
{ {
public static Tensor[] partitioned_call(Tensors args, EagerDefinedFunction f, DataType[] tout,
bool executing_eagerly, string config, string executor_type)
{
if (tout is null)
{
throw new NotImplementedException();
}

if (config is null)
{
config = function_utils.get_disabled_rewriter_config();
}

if (executor_type is null)
{
executor_type = "";
}

if (executing_eagerly)
{
throw new NotImplementedException();
}

var converted_args = args.Select(x => ops.convert_to_tensor(x)).ToArray();
AttrValue tin_attr = new()
{
List = new AttrValue.Types.ListValue()
};
tin_attr.List.Type.AddRange(args.Select(x => x.dtype.as_datatype_enum()));
AttrValue tout_attr = new()
{
List = new AttrValue.Types.ListValue()
};
tout_attr.List.Type.AddRange(tout);
AttrValue func_attr = new()
{
Func = new NameAttrList()
};
func_attr.Func.Name = f.Name;
AttrValue executor_type_attr = new AttrValue()
{
S = tf.compat.as_bytes(executor_type)
};
AttrValue config_proto = new AttrValue()
{
S = ByteString.CopyFromUtf8(executor_type)
};

var graph = ops.get_default_graph();
f.AddToGraph(graph);
// TODO(Rinne): complete it with `f.stateful`
var op_name = "PartitionedCall";
string xla_compile_attr = "_XlaMustCompile";
Dictionary<string, AttrValue> op_attrs = new();
op_attrs["Tin"] = tin_attr;
op_attrs["Tout"] = tout_attr;
op_attrs["f"] = func_attr;
op_attrs["config_proto"] = config_proto;
op_attrs["executor_type"] = executor_type_attr;
// TODO(Rinne): deal with `f.definition`.
var op = graph.create_op(op_name, args, tout.Select(x => x.as_tf_dtype()).ToArray(),
name: op_name, attrs: op_attrs);
var outputs = op.outputs;
// TODO(Rinne): deal with `f.graph`.
return outputs;
}
public static Tensor scan( public static Tensor scan(
Func<Tensor, Tensor, Tensor> fn, Func<Tensor, Tensor, Tensor> fn,
Tensor elems, Tensor elems,


+ 83
- 0
src/TensorFlowNET.Core/Operations/gen_functional_ops.cs View File

@@ -0,0 +1,83 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Xml.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
{
public class gen_functional_ops
{
public static Tensor[] partitioned_call(Tensors args, TF_DataType[] tout, EagerDefinedFunction f,
string config = "", string config_proto = "", string executor_type = "", string name = null)
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("PartitionedCall", name,
args, tout, f, config, config_proto, executor_type));
}
catch (Exception)
{

}
}

if (config is null)
{
config = "";
}
if (config_proto is null)
{
config_proto = "";
}
if (executor_type is null)
{
executor_type = "";
}
Dictionary<string, object> kwargs = new();
kwargs["args"] = args;
kwargs["Tout"] = tout;
kwargs["f"] = f;
kwargs["config"] = config;
kwargs["config_proto"] = config_proto;
kwargs["executor_type"] = executor_type;
var output = tf.OpDefLib._apply_op_helper("PartitionedCall",
name, kwargs);
var result = output.outputs;
if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return result;
}

public static Tensor[] partitioned_call_eager_fallback(Tensors args, TF_DataType[] tout, EagerDefinedFunction f,
string config, string config_proto, string executor_type, string name, Context ctx)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
if(config is null)
{
config = "";
}
if(config_proto is null)
{
config_proto = "";
}
if(executor_type is null)
{
executor_type = "";
}
object[] attrs = new object[]
{

};
}
}
}

+ 20
- 5
src/TensorFlowNET.Core/Operations/handle_data_util.cs View File

@@ -1,7 +1,9 @@
using System;
using Google.Protobuf;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Eager; using Tensorflow.Eager;
using static Tensorflow.CppShapeInferenceResult.Types;


namespace Tensorflow.Operations namespace Tensorflow.Operations
{ {
@@ -11,18 +13,31 @@ namespace Tensorflow.Operations
{ {
if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant)
{ {
SafeTensorHandle handle_data;
HandleData handle_data;
if(source_t is EagerTensor) if(source_t is EagerTensor)
{ {
handle_data = source_t.Handle;
handle_data = source_t.HandleData;
} }
else else
{ {
handle_data = ops.get_resource_handle_data(source_t); handle_data = ops.get_resource_handle_data(source_t);
} }
throw new NotImplementedException();
//if(handle_data is not null && handle_data.)
if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null
&& handle_data.ShapeAndType.Count > 0)
{
set_handle_data(target_t, handle_data);
}
}
}

public static void set_handle_data(Tensor target_t, HandleData handle_data)
{
if(target_t is EagerTensor)
{
target_t.HandleData = handle_data;
return;
} }
c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray());
} }
} }
} }

+ 17
- 1
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -39,7 +39,7 @@ namespace Tensorflow


public static bool is_resource_variable(IVariableV1 var) public static bool is_resource_variable(IVariableV1 var)
{ {
return var is ResourceVariable;
return var is BaseResourceVariable;
} }
public static bool is_resource_variable(Trackable var) public static bool is_resource_variable(Trackable var)
@@ -231,5 +231,21 @@ namespace Tensorflow
} }
} }
} }

public static void _maybe_set_handle_data(TF_DataType dtype, Tensor handle, Tensor tensor)
{
if(dtype == dtypes.variant)
{
var handle_data = get_eager_safe_handle_data(handle);
if(handle_data.IsSet && handle_data.ShapeAndType.Count > 1)
{
tensor.HandleData = new HandleData()
{
IsSet = true
};
tensor.HandleData.ShapeAndType.AddRange(handle_data.ShapeAndType.Skip(1));
}
}
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs View File

@@ -479,7 +479,7 @@ namespace Tensorflow {
/// </summary> /// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType> ShapeAndType { public pbc::RepeatedField<global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType> ShapeAndType {
get { return shapeAndType_; }
get { return shapeAndType_; }
} }


[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]


+ 16
- 9
src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs View File

@@ -277,15 +277,15 @@ namespace Tensorflow {
get { return Descriptor; } get { return Descriptor; }
} }


[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SavedObject() {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SavedObject() {
OnConstruction(); OnConstruction();
} }


partial void OnConstruction(); partial void OnConstruction();


[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SavedObject(SavedObject other) : this() {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SavedObject(SavedObject other) : this() {
children_ = other.children_.Clone(); children_ = other.children_.Clone();
dependencies_ = other.dependencies_.Clone(); dependencies_ = other.dependencies_.Clone();
slotVariables_ = other.slotVariables_.Clone(); slotVariables_ = other.slotVariables_.Clone();
@@ -329,7 +329,9 @@ namespace Tensorflow {
public const int ChildrenFieldNumber = 1; public const int ChildrenFieldNumber = 1;
private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec
= pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_dependencies_codec
= pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
/// <summary> /// <summary>
/// Objects which this object depends on: named edges in the dependency /// Objects which this object depends on: named edges in the dependency
@@ -501,7 +503,8 @@ namespace Tensorflow {
return true; return true;
} }
if(!children_.Equals(other.children_)) return false; if(!children_.Equals(other.children_)) return false;
if(!slotVariables_.Equals(other.slotVariables_)) return false;
if (!dependencies_.Equals(other.dependencies_)) return false;
if (!slotVariables_.Equals(other.slotVariables_)) return false;
if (!object.Equals(UserObject, other.UserObject)) return false; if (!object.Equals(UserObject, other.UserObject)) return false;
if (!object.Equals(Asset, other.Asset)) return false; if (!object.Equals(Asset, other.Asset)) return false;
if (!object.Equals(Function, other.Function)) return false; if (!object.Equals(Function, other.Function)) return false;
@@ -519,6 +522,7 @@ namespace Tensorflow {
public override int GetHashCode() { public override int GetHashCode() {
int hash = 1; int hash = 1;
hash ^= children_.GetHashCode(); hash ^= children_.GetHashCode();
hash ^= dependencies_.GetHashCode();
hash ^= slotVariables_.GetHashCode(); hash ^= slotVariables_.GetHashCode();
if (kindCase_ == KindOneofCase.UserObject) hash ^= UserObject.GetHashCode(); if (kindCase_ == KindOneofCase.UserObject) hash ^= UserObject.GetHashCode();
if (kindCase_ == KindOneofCase.Asset) hash ^= Asset.GetHashCode(); if (kindCase_ == KindOneofCase.Asset) hash ^= Asset.GetHashCode();
@@ -544,6 +548,7 @@ namespace Tensorflow {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) { public void WriteTo(pb::CodedOutputStream output) {
children_.WriteTo(output, _repeated_children_codec); children_.WriteTo(output, _repeated_children_codec);
children_.WriteTo(output, _repeated_dependencies_codec);
slotVariables_.WriteTo(output, _repeated_slotVariables_codec); slotVariables_.WriteTo(output, _repeated_slotVariables_codec);
if (kindCase_ == KindOneofCase.UserObject) { if (kindCase_ == KindOneofCase.UserObject) {
output.WriteRawTag(34); output.WriteRawTag(34);
@@ -587,6 +592,7 @@ namespace Tensorflow {
public int CalculateSize() { public int CalculateSize() {
int size = 0; int size = 0;
size += children_.CalculateSize(_repeated_children_codec); size += children_.CalculateSize(_repeated_children_codec);
size += children_.CalculateSize(_repeated_dependencies_codec);
size += slotVariables_.CalculateSize(_repeated_slotVariables_codec); size += slotVariables_.CalculateSize(_repeated_slotVariables_codec);
if (kindCase_ == KindOneofCase.UserObject) { if (kindCase_ == KindOneofCase.UserObject) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(UserObject); size += 1 + pb::CodedOutputStream.ComputeMessageSize(UserObject);
@@ -619,7 +625,7 @@ namespace Tensorflow {
return size; return size;
} }


[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
//[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(SavedObject other) { public void MergeFrom(SavedObject other) {
if (other == null) { if (other == null) {
return; return;
@@ -682,7 +688,7 @@ namespace Tensorflow {
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
} }


[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
//[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) { public void MergeFrom(pb::CodedInputStream input) {
uint tag; uint tag;
while ((tag = input.ReadTag()) != 0) { while ((tag = input.ReadTag()) != 0) {
@@ -692,9 +698,10 @@ namespace Tensorflow {
break; break;
case 10: { case 10: {
children_.AddEntriesFrom(input, _repeated_children_codec); children_.AddEntriesFrom(input, _repeated_children_codec);
dependencies_.AddRange(children_.Except(dependencies_));
break; break;
} }
case 26: {
case 26: {
slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec);
break; break;
} }


+ 5
- 0
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -109,7 +109,12 @@ https://tensorflownet.readthedocs.io</Description>
<ItemGroup> <ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> <PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
<PackageReference Include="OneOf" Version="3.0.223" />
<PackageReference Include="Protobuf.Text" Version="0.6.1" /> <PackageReference Include="Protobuf.Text" Version="0.6.1" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup> </ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\Tensorflow.Common\Tensorflow.Common.csproj" />
</ItemGroup>
</Project> </Project>

+ 1
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -87,6 +87,7 @@ namespace Tensorflow
public object Tag { get; set; } public object Tag { get; set; }
protected new SafeTensorHandle _handle; protected new SafeTensorHandle _handle;
public virtual SafeTensorHandle Handle => _handle; public virtual SafeTensorHandle Handle => _handle;
public Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData { get; internal set; }


protected SafeEagerTensorHandle _eagerTensorHandle; protected SafeEagerTensorHandle _eagerTensorHandle;
/// <summary> /// <summary>


+ 4
- 3
src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs View File

@@ -14,18 +14,19 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using OneOf;
using Tensorflow.Checkpoint; using Tensorflow.Checkpoint;


namespace Tensorflow namespace Tensorflow
{ {
public class MySaveableObject public class MySaveableObject
{ {
protected Maybe<Tensor, BaseResourceVariable> _op;
protected OneOf<Tensor, BaseResourceVariable> _op;
public Tensor op public Tensor op
{ {
get get
{ {
if(_op.TryGet<Tensor>(out var tensor))
if(_op.TryPickT0(out var tensor, out var _))
{ {
return tensor; return tensor;
} }
@@ -43,7 +44,7 @@ namespace Tensorflow
{ {
get get
{ {
if (_op.TryGet<BaseResourceVariable>(out var v))
if (_op.TryPickT1(out var v, out var _))
{ {
return v; return v;
} }


+ 50
- 3
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -25,11 +25,32 @@ namespace Tensorflow.Training.Saving.SavedModel
/// <param name="saved_concrete_function"></param> /// <param name="saved_concrete_function"></param>
/// <param name="concrete_functions"></param> /// <param name="concrete_functions"></param>
/// <returns></returns> /// <returns></returns>
public static ConcreteFunction recreate_function(SavedFunction saved_concrete_function,
public static Function recreate_function(SavedFunction saved_function,
IDictionary<string, ConcreteFunction> concrete_functions) IDictionary<string, ConcreteFunction> concrete_functions)
{ {
var function_spec = _deserialize_function_spec_as_nonmethod(saved_concrete_function.FunctionSpec);
return null;
var function_spec = _deserialize_function_spec_as_nonmethod(saved_function.FunctionSpec);

List<ConcreteFunction> concrete_function_objects = new();
foreach(var concrete_function_name in saved_function.ConcreteFunctions)
{
concrete_function_objects.Add(concrete_functions[concrete_function_name]);
}
foreach(var cf in concrete_function_objects)
{
cf._set_function_spec(function_spec);
}

foreach(var function_name in saved_function.ConcreteFunctions)
{
var function = concrete_functions[function_name];
if(_concrete_function_callable_with(function, null, false))
{
return new RestoredFunction(null, function, "function_from_deserialization");
}
}
return new RestoredFunction(x => x, new ConcreteFunction(x => x, TF_DataType.TF_FLOAT), "function_return_itself");
//throw new ValueError("Unexpected runtime behavior, please submit an issue to " +
// "https://github.com/SciSharp/TensorFlow.NET/issues");
} }


public static Dictionary<string, ConcreteFunction> load_function_def_library(FunctionDefLibrary library, public static Dictionary<string, ConcreteFunction> load_function_def_library(FunctionDefLibrary library,
@@ -385,5 +406,31 @@ namespace Tensorflow.Training.Saving.SavedModel
JitCompile = function_spec_proto.JitCompile JitCompile = function_spec_proto.JitCompile
}; };
} }

private static Tensors _call_concrete_function(ConcreteFunction function, Tensors inputs)
{
// TODO(Rinne): var expected_structure = function.func_graph.structured_input_signature
return function.CallFlat(inputs, function.CapturedInputs);
}

private static bool _concrete_function_callable_with(ConcreteFunction function, Tensors inputs, bool allow_conversion)
{
// TODO(Rinne): revise it.
return true;
}
}

public class RestoredFunction : Function
{
public RestoredFunction(Func<Tensors, Tensors> function, ConcreteFunction concrete_function,
string name, bool auto_graph = true): base(function, name, auto_graph)
{
_concrete_variable_creation_fn = concrete_function;
}

protected override bool _run_functions_eagerly()
{
return false;
}
} }
} }

+ 22
- 14
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -14,6 +14,7 @@ using Tensorflow.Variables;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Trackables; using Tensorflow.Trackables;
using OneOf;


namespace Tensorflow namespace Tensorflow
{ {
@@ -44,6 +45,8 @@ namespace Tensorflow
_asset_file_def = meta_graph.AssetFileDef; _asset_file_def = meta_graph.AssetFileDef;
_operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr);
_proto = object_graph_proto; _proto = object_graph_proto;
// Debug(Rinne)
var temp = _proto.ToString();
_export_dir = export_dir; _export_dir = export_dir;
// TODO: `this._concrete_functions` and `this._restored_concrete_functions` // TODO: `this._concrete_functions` and `this._restored_concrete_functions`
_concrete_functions = function_deserialization.load_function_def_library( _concrete_functions = function_deserialization.load_function_def_library(
@@ -259,9 +262,9 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <param name="proto"></param> /// <param name="proto"></param>
/// <returns></returns> /// <returns></returns>
private Dictionary<Maybe<string, int>, int> _get_node_dependencies(SavedObject proto)
private Dictionary<OneOf<string, int>, int> _get_node_dependencies(SavedObject proto)
{ {
Dictionary<Maybe<string, int>, int> dependencies = new();
Dictionary<OneOf<string, int>, int> dependencies = new();
foreach(var refer in proto.Dependencies) foreach(var refer in proto.Dependencies)
{ {
dependencies[refer.LocalName] = refer.NodeId; dependencies[refer.LocalName] = refer.NodeId;
@@ -375,11 +378,6 @@ namespace Tensorflow
// Re-create everything. // Re-create everything.
foreach (var (node_id, proto) in _iter_all_nodes()) foreach (var (node_id, proto) in _iter_all_nodes())
{ {
if(node_id == 45)
{
// TODelete
Console.WriteLine();
}
if (nodes.ContainsKey(node_id)) if (nodes.ContainsKey(node_id))
{ {
continue; continue;
@@ -474,7 +472,7 @@ namespace Tensorflow
} }
} }


private void _setup_function_captures(string concrete_function_name, IDictionary<Maybe<string, int>, Trackable> nodes)
private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, Trackable> nodes)
{ {
if (_restored_concrete_functions.Contains(concrete_function_name)) if (_restored_concrete_functions.Contains(concrete_function_name))
{ {
@@ -509,6 +507,11 @@ namespace Tensorflow
/// <param name="node_id"></param> /// <param name="node_id"></param>
private void _add_object_graph_edges(SavedObject proto, int node_id) private void _add_object_graph_edges(SavedObject proto, int node_id)
{ {
// Debug(Rinne)
if(node_id == 1)
{
Console.WriteLine();
}
var obj = _nodes[node_id]; var obj = _nodes[node_id];
var setter = _node_setters[node_id]; var setter = _node_setters[node_id];


@@ -549,8 +552,13 @@ namespace Tensorflow
private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes) private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
{ {
// skip the registered classes. // skip the registered classes.
if(node_id == 16)
{
// Debug(Rinne)
Console.WriteLine();
}


Dictionary<Maybe<string, int>, Trackable> dependencies = new();
Dictionary<OneOf<string, int>, Trackable> dependencies = new();
foreach(var item in _get_node_dependencies(proto)) foreach(var item in _get_node_dependencies(proto))
{ {
dependencies[item.Key] = nodes[item.Value]; dependencies[item.Key] = nodes[item.Value];
@@ -571,7 +579,7 @@ namespace Tensorflow
/// <param name="proto"></param> /// <param name="proto"></param>
/// <param name="node_id"></param> /// <param name="node_id"></param>
/// <param name="dependencies"></param> /// <param name="dependencies"></param>
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<Maybe<string, int>, Trackable> dependencies)
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, Trackable> dependencies)
{ {
return proto.KindCase switch return proto.KindCase switch
{ {
@@ -637,10 +645,10 @@ namespace Tensorflow
} }
} }


private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto,
Dictionary<Maybe<string, int>, Trackable> dependencies)
private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto,
Dictionary<OneOf<string, int>, Trackable> dependencies)
{ {
var fn = function_deserialization.recreate_function(proto, null);
var fn = function_deserialization.recreate_function(proto, _concrete_functions);
foreach (var name in proto.ConcreteFunctions) foreach (var name in proto.ConcreteFunctions)
{ {
_setup_function_captures(name, dependencies); _setup_function_captures(name, dependencies);
@@ -649,7 +657,7 @@ namespace Tensorflow
} }


private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
IDictionary<Maybe<string, int>, Trackable> dependencies)
IDictionary<OneOf<string, int>, Trackable> dependencies)
{ {
var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions);
_setup_function_captures(proto.ConcreteFunctionName, dependencies); _setup_function_captures(proto.ConcreteFunctionName, dependencies);


+ 18
- 18
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -14,6 +14,7 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using OneOf;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
@@ -174,7 +175,7 @@ namespace Tensorflow
full_name = name + "_" + attr; full_name = name + "_" + attr;
} }
var op = factory(full_name); var op = factory(full_name);
if(op.TryGet<BaseResourceVariable>(out var variable))
if(op.TryPickT0(out var variable, out var saveable))
{ {
foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name))
{ {
@@ -183,7 +184,6 @@ namespace Tensorflow
} }
else else
{ {
var saveable = op.GetValue<MySaveableObject>();
foreach (var v in saveable_objects_for_op(saveable, saveable.name)) foreach (var v in saveable_objects_for_op(saveable, saveable.name))
{ {
yield return v; yield return v;
@@ -252,11 +252,11 @@ namespace Tensorflow
return names_to_saveables; return names_to_saveables;
} }


public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj)
public static IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj)
{ {
// skip the process of type `PythonState` // skip the process of type `PythonState`


Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "")
OneOf<BaseResourceVariable, MySaveableObject> create_saveable(string name = "")
{ {
// skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`.
var tensor_dict = obj.serialize_to_tensors(); var tensor_dict = obj.serialize_to_tensors();
@@ -272,14 +272,14 @@ namespace Tensorflow
string spec_name = name + TrackableUtils.escape_local_name(tensor_name); string spec_name = name + TrackableUtils.escape_local_name(tensor_name);


IDictionary<string, Tensor> internal_dict; IDictionary<string, Tensor> internal_dict;
if (maybe_tensor.TryGet<Tensor>(out var tensor))
if (maybe_tensor.TryPickT0(out var tensor, out var dic))
{ {
internal_dict = new Dictionary<string, Tensor>(); internal_dict = new Dictionary<string, Tensor>();
internal_dict[""] = tensor; internal_dict[""] = tensor;
} }
else else
{ {
internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>();
internal_dict = dic;
} }


foreach (var item in internal_dict) foreach (var item in internal_dict)
@@ -292,7 +292,7 @@ namespace Tensorflow


if (trackable_has_serialize_to_tensor(obj)) if (trackable_has_serialize_to_tensor(obj))
{ {
Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new();
Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> res = new();
res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable;
return res; return res;
} }
@@ -316,9 +316,9 @@ namespace Tensorflow
/// Converts a list of SaveableObjects to a tensor dictionary. /// Converts a list of SaveableObjects to a tensor dictionary.
/// </summary> /// </summary>
/// <param name="saveables"></param> /// <param name="saveables"></param>
public static Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables)
public static Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables)
{ {
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
foreach (var saveable in saveables) foreach (var saveable in saveables)
{ {
foreach (var spec in saveable.specs) foreach (var spec in saveable.specs)
@@ -328,7 +328,7 @@ namespace Tensorflow
var slice_spec = convert_to_string(spec.slice_spec); var slice_spec = convert_to_string(spec.slice_spec);
if (!string.IsNullOrEmpty(slice_spec)) if (!string.IsNullOrEmpty(slice_spec))
{ {
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValue<IDictionary<string, Tensor>>()[slice_spec] = spec.tensor;
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).AsT1[slice_spec] = spec.tensor;
} }
else else
{ {
@@ -343,7 +343,7 @@ namespace Tensorflow
/// Generates `Trackable._restore_from_tensors` from SaveableObjects. /// Generates `Trackable._restore_from_tensors` from SaveableObjects.
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
public static Func<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>> saveable_object_to_restore_fn(IList<MySaveableObject> saveables)
public static Func<IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>> saveable_object_to_restore_fn(IList<MySaveableObject> saveables)
{ {
return (restored_tensors) => return (restored_tensors) =>
{ {
@@ -359,14 +359,14 @@ namespace Tensorflow


var maybe_tensor = restored_tensors[name]; var maybe_tensor = restored_tensors[name];
IDictionary<string, Tensor> dict; IDictionary<string, Tensor> dict;
if(maybe_tensor.TryGet<Tensor>(out var tensor))
if(maybe_tensor.TryPickT0(out var tensor, out var dic))
{ {
dict = new Dictionary<string, Tensor>(); dict = new Dictionary<string, Tensor>();
dict[""] = tensor; dict[""] = tensor;
} }
else else
{ {
dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>();
dict = dic;
} }
saveable_restored_tensors.Add(dict[slice_spec]); saveable_restored_tensors.Add(dict[slice_spec]);
} }
@@ -381,18 +381,18 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <param name="saveable_fn_by_name"></param> /// <param name="saveable_fn_by_name"></param>
/// <param name="temp_session"></param> /// <param name="temp_session"></param>
public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects(
public static IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects(
IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session) IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session)
{ {
if (saveable_fn_by_name.Count > 0) if (saveable_fn_by_name.Count > 0)
{ {
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
} }
var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
var res = new Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>();
return res; return res;
} }


public static Maybe<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
public static OneOf<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, OneOf<BaseResourceVariable, MySaveableObject>> factory,
bool call_with_mapped_captures = false) bool call_with_mapped_captures = false)
{ {
return factory(key); return factory(key);
@@ -412,7 +412,7 @@ namespace Tensorflow
public object Obj => _obj; public object Obj => _obj;
public IList<MySaveableObject> mySaveables=> _saveables; public IList<MySaveableObject> mySaveables=> _saveables;


public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
public override IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
{ {
return saveable_object_util.saveable_object_to_tensor_dict(_saveables); return saveable_object_util.saveable_object_to_tensor_dict(_saveables);
} }
@@ -422,7 +422,7 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <param name="restored_tensors"></param> /// <param name="restored_tensors"></param>
/// <returns></returns> /// <returns></returns>
public override IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors)
public override IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors)
{ {
List<string> expected_keys = new(); List<string> expected_keys = new();
foreach(var saveable in _saveables) foreach(var saveable in _saveables)


+ 9
- 8
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -14,6 +14,7 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using OneOf;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
@@ -43,8 +44,8 @@ namespace Tensorflow.Train
protected IList<TrackableReference> _unconditional_checkpoint_dependencies; protected IList<TrackableReference> _unconditional_checkpoint_dependencies;
protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies; protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies;


protected IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories =
new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
protected IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories =
new Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>();
private bool _manual_tracking = true; private bool _manual_tracking = true;


private static Trackable _none = new AutoTrackable(); private static Trackable _none = new AutoTrackable();
@@ -73,7 +74,7 @@ namespace Tensorflow.Train
public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; }
public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; }
public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies; public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies;
public IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories
public IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories
{ {
get get
{ {
@@ -249,9 +250,9 @@ namespace Tensorflow.Train
return self_tensor_map.Keys.ToList(); return self_tensor_map.Keys.ToList();
} }


public virtual IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
public virtual IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
{ {
Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "")
OneOf<BaseResourceVariable, MySaveableObject> create_saveable(string name = "")
{ {
throw new NotImplementedException(); throw new NotImplementedException();
//return new TrackableSaveable(this, null, name, null, null); //return new TrackableSaveable(this, null, name, null, null);
@@ -259,7 +260,7 @@ namespace Tensorflow.Train
if (saveable_object_util.trackable_has_serialize_to_tensor(this)) if (saveable_object_util.trackable_has_serialize_to_tensor(this))
{ {
// TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`).
Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new();
Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> res = new();
res[""] = create_saveable; res[""] = create_saveable;
return res; return res;
} }
@@ -278,12 +279,12 @@ namespace Tensorflow.Train
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
/// <exception cref="NotImplementedException"></exception> /// <exception cref="NotImplementedException"></exception>
public virtual IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
public virtual IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }


public virtual IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors)
public virtual IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }


+ 23
- 0
src/TensorFlowNET.Core/Util/function_utils.cs View File

@@ -0,0 +1,23 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Util
{
internal static class function_utils
{
private static string _rewriter_config_optimizer_disabled;
public static string get_disabled_rewriter_config()
{
if(_rewriter_config_optimizer_disabled is null)
{
var config = new ConfigProto();
var rewriter_config = config.GraphOptions.RewriteOptions;
rewriter_config.DisableMetaOptimizer = true;
_rewriter_config_optimizer_disabled = config.ToString();
}
return _rewriter_config_optimizer_disabled;
}
}
}

+ 4
- 3
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using Tensorflow.Checkpoint; using Tensorflow.Checkpoint;
using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Training.Saving.SavedModel;
using OneOf;


namespace Tensorflow namespace Tensorflow
{ {
@@ -155,7 +156,7 @@ namespace Tensorflow
{ {
variable_accessed(this); variable_accessed(this);
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
// _maybe_set_handle_data(_dtype, _handle, result);
resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);


// have to set shape when converting to substituent placeholder // have to set shape when converting to substituent placeholder
if (result.shape.ndim == -1) if (result.shape.ndim == -1)
@@ -293,9 +294,9 @@ namespace Tensorflow
resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options);
} }


public override IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
public override IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
{ {
var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
var res = new Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>();
res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this;
return res; return res;
} }


+ 9
- 1
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -124,7 +124,9 @@ namespace Tensorflow
initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; initializer_op = gen_state_ops.assign(handle, _initial_value, true).op;


ops.colocate_with(initializer_op); ops.colocate_with(initializer_op);

tf.device(handle.Device);
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
_graph_element = gen_array_ops.identity(handle, name = "read"); _graph_element = gen_array_ops.identity(handle, name = "read");
ops.add_to_collections<IVariableV1>(collections, this); ops.add_to_collections<IVariableV1>(collections, this);
_dtype = handle.dtype; _dtype = handle.dtype;
@@ -141,6 +143,12 @@ namespace Tensorflow
gen_resource_variable_ops.assign_variable_op(handle, _initial_value); gen_resource_variable_ops.assign_variable_op(handle, _initial_value);
initializer_op = null; initializer_op = null;
_graph_element = null; _graph_element = null;
if (!string.IsNullOrEmpty(caching_device))
{
tf.device(caching_device);
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
}
_dtype = _initial_value.dtype.as_base_dtype(); _dtype = _initial_value.dtype.as_base_dtype();
// initial_value = _in_graph_mode ? initial_value : null; // initial_value = _in_graph_mode ? initial_value : null;
} }


+ 5
- 4
src/TensorFlowNET.Core/Variables/UninitializedVariable.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.Variables
/// <summary> /// <summary>
/// A variable with no initializer. /// A variable with no initializer.
/// </summary> /// </summary>
public sealed class UninitializedVariable: BaseResourceVariable
public sealed class UninitializedVariable: BaseResourceVariable, IVariableV1
{ {
// TODO: complete the arg list. // TODO: complete the arg list.
public UninitializedVariable( public UninitializedVariable(
@@ -23,6 +23,7 @@ namespace Tensorflow.Variables
{ {
string unique_id = ""; string unique_id = "";
string handle_name = ""; string handle_name = "";
Tensor created_handle = null;
tf_with(ops.init_scope(), (x) => tf_with(ops.init_scope(), (x) =>
{ {
_in_graph_mode = !tf.Context.executing_eagerly(); _in_graph_mode = !tf.Context.executing_eagerly();
@@ -40,7 +41,7 @@ namespace Tensorflow.Variables
unique_id = $"{handle_name}-{ops.uid()}"; unique_id = $"{handle_name}-{ops.uid()}";
shared_name = null; shared_name = null;
} }
var handle = resource_variable_ops.variable_handle_from_shape_and_dtype(
created_handle = resource_variable_ops.variable_handle_from_shape_and_dtype(
shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data); shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data);
// skip the assignment of `handle._parent_trackable` because of lack of API. // skip the assignment of `handle._parent_trackable` because of lack of API.
// skip the assignment of `handle._name` and `handle._unique_id` because of accessability. // skip the assignment of `handle._name` and `handle._unique_id` because of accessability.
@@ -51,7 +52,7 @@ namespace Tensorflow.Variables
{ {
tf.device(handle.Device); tf.device(handle.Device);
var value = gen_resource_variable_ops.read_variable_op(handle, dtype); var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
// _maybe_set_handle_data(dtype, handle, value)
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
_graph_element = value; _graph_element = value;
}); });
ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this);
@@ -64,7 +65,7 @@ namespace Tensorflow.Variables
}); });
_shape = shape; _shape = shape;
_dtype = dtype; _dtype = dtype;
base.__init__(trainable, handle, unique_id: unique_id, handle_name: handle_name);
base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name);
} }
} }
} }

+ 6
- 2
src/TensorFlowNET.Core/ops.cs View File

@@ -26,6 +26,7 @@ using Tensorflow.Eager;
using Tensorflow.Graphs; using Tensorflow.Graphs;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using static Tensorflow.CppShapeInferenceResult.Types;


namespace Tensorflow namespace Tensorflow
{ {
@@ -572,9 +573,12 @@ namespace Tensorflow
return get_default_graph().building_function; return get_default_graph().building_function;
} }


public static SafeTensorHandle get_resource_handle_data(Tensor graph_op)
public static HandleData get_resource_handle_data(Tensor graph_op)
{ {
throw new NotImplementedException();
// This implementation hasn't been checked for some reasons.
// If it throws an exception in the future, please check it.
var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
} }


public static void dismantle_graph(Graph graph) public static void dismantle_graph(Graph graph)


+ 2
- 2
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -40,7 +40,7 @@ namespace Tensorflow.Keras.Engine
/// <summary> /// <summary>
/// Arguments initialize layer. /// Arguments initialize layer.
/// </summary> /// </summary>
LayerArgs args;
internal LayerArgs args;


/// <summary> /// <summary>
/// Indicates whether `build` needs to be called upon layer call, to create /// Indicates whether `build` needs to be called upon layer call, to create
@@ -147,7 +147,7 @@ namespace Tensorflow.Keras.Engine
List<INode> outboundNodes; List<INode> outboundNodes;
public List<INode> OutboundNodes => outboundNodes; public List<INode> OutboundNodes => outboundNodes;


public JObject SerializedAttributes { get; set; }
public Dictionary<string, object> SerializedAttributes { get; set; }


ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>();
public CallContext CallContext => callContext.Value; public CallContext CallContext => callContext.Value;


+ 58
- 21
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Saving
{ {
public class KerasObjectLoader public class KerasObjectLoader
{ {
internal static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects;
internal static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES;
private SavedMetadata _metadata; private SavedMetadata _metadata;
private SavedObjectGraph _proto; private SavedObjectGraph _proto;
private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); private Dictionary<int, string> _node_paths = new Dictionary<int, string>();
@@ -39,7 +39,13 @@ namespace Tensorflow.Keras.Saving


static KerasObjectLoader() static KerasObjectLoader()
{ {
PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null;
var endPoints = new CommonEndPoints();
PUBLIC_ATTRIBUTES = new Dictionary<string, Trackable>();
foreach (var key in endPoints._all_checkpointable_objects.Concat(endPoints._all_functions))
{
PUBLIC_ATTRIBUTES[key] = null;
}
PUBLIC_ATTRIBUTES[SavedModel.Constants.KERAS_ATTR] = null;
} }


public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def)
@@ -125,8 +131,14 @@ namespace Tensorflow.Keras.Saving
continue; continue;
} }


// TODO: deal with `RevivedLayer` and `RevivedInputLayer`.
layers_revived_from_config.Add(node as Layer);
if(node is RevivedLayer or RevivedInputLayer)
{
layers_revived_from_saved_model.Add(node as Layer);
}
else
{
layers_revived_from_config.Add(node as Layer);
}
} }


_finalize_saved_model_layers(layers_revived_from_saved_model); _finalize_saved_model_layers(layers_revived_from_saved_model);
@@ -171,10 +183,13 @@ namespace Tensorflow.Keras.Saving
// TODO(Rinne): implement it // TODO(Rinne): implement it
} }
} }
// `model.__init__(layers, config["name"])`
s.InitLayers(layers);
s.Name = config["name"].ToObject<string>();

// `model.__init__(layers, config["name"])`InitLayers(layers);
s = new Sequential(new SequentialArgs(){
Layers = layers.Select(x => x as ILayer).ToList(),
Name = config["name"].ToObject<string>()
});
//s.Name = config["name"].ToObject<string>();
if(s.input is null || s.input.Length == 0) if(s.input is null || s.input.Length == 0)
{ {
var first_layer = _get_child_layer_node_ids(model_id)[0]; var first_layer = _get_child_layer_node_ids(model_id)[0];
@@ -205,7 +220,12 @@ namespace Tensorflow.Keras.Saving


private void _set_network_attributes_from_metadata(Model revived_object) private void _set_network_attributes_from_metadata(Model revived_object)
{ {
// TODO: implement it.
var metadata = revived_object.SerializedAttributes["matadata"] as JObject;
if (metadata.ContainsKey("dtype"))
{
// TODO(Rinne): set_dtype_policy.
}
revived_object.args.Trainable = metadata["trainable"].Value<bool>();
} }


/// <summary> /// <summary>
@@ -330,7 +350,7 @@ namespace Tensorflow.Keras.Saving
private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id) private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id)
{ {
Trackable obj; Trackable obj;
if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER)
if(identifier == SavedModel.Constants.METRIC_IDENTIFIER)
{ {
// TODO(Rinne): implement it. // TODO(Rinne): implement it.
return (null, null); return (null, null);
@@ -429,25 +449,26 @@ namespace Tensorflow.Keras.Saving
return obj; return obj;
} }


private void _revive_setter(object layer, object name, object value)
private void _revive_setter(object obj, object name, object value)
{ {
Debug.Assert(name is string); Debug.Assert(name is string);
Debug.Assert(layer is Layer);
Debug.Assert(obj is Layer);
Layer layer = (Layer)obj;
if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) if(PUBLIC_ATTRIBUTES.ContainsKey(name as string))
{ {
if(value is Trackable) if(value is Trackable)
{ {
(layer as Layer)._track_trackable(value as Trackable, name as string);
layer._track_trackable(value as Trackable, name as string);
} }
if((layer as Layer).SerializedAttributes is null)
if(layer.SerializedAttributes is null)
{ {
(layer as Layer).SerializedAttributes = new JObject();
layer.SerializedAttributes = new Dictionary<string, object>();
} }
(layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value);
layer.SerializedAttributes[name as string] = value;
} }
else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success)
else if(layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success)
{ {
(layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true);
functional._track_trackable(value as Trackable, name as string, overwrite: true);
} }
else else
{ {
@@ -521,7 +542,7 @@ namespace Tensorflow.Keras.Saving
} }


var metric_list_node_id = _search_for_child_node(node_id, new string[] { var metric_list_node_id = _search_for_child_node(node_id, new string[] {
Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics"
SavedModel.Constants.KERAS_ATTR, "layer_metrics"
}); });
if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) if(metric_list_node_id is not null && obj is Model model && model.metrics is not null)
{ {
@@ -547,7 +568,7 @@ namespace Tensorflow.Keras.Saving
// skip the check for registered identifier // skip the check for registered identifier


Action<object, object, object> setter; Action<object, object, object> setter;
if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier))
if (SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier))
{ {
setter = _revive_setter; setter = _revive_setter;
} }
@@ -659,7 +680,23 @@ namespace Tensorflow.Keras.Saving


private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata) private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata)
{ {
// TODO: deal with `RevivedLayer`
if(layer.SerializedAttributes is null || layer.SerializedAttributes.Count == 0)
{
layer.SerializedAttributes = new Dictionary<string, object>();
layer.SerializedAttributes["metadata"] = metadata;
}
}

private static object _get_keras_attr(Layer layer)
{
if((layer.SerializedAttributes ?? new Dictionary<string, object>()).TryGetValue(SavedModel.Constants.KERAS_ATTR, out var value))
{
return value;
}
else
{
return null;
}
} }


/// <summary> /// <summary>


+ 9
- 4
src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs View File

@@ -24,17 +24,22 @@ namespace Tensorflow.Keras.Saving.SavedModel
} }
} }


public static void _revive_setter(object layer, object name, object value)
public static void _revive_setter(object obj, object name, object value)
{ {
Debug.Assert(name is string); Debug.Assert(name is string);
Debug.Assert(layer is Layer);
Debug.Assert(obj is Layer);
Layer layer = (Layer)obj;
if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string)) if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string))
{ {
if (value is Trackable trackable) if (value is Trackable trackable)
{ {
(layer as Layer)._track_trackable(trackable, name as string);
layer._track_trackable(trackable, name as string);
} }
(layer as Layer).SerializedAttributes[name] = JToken.FromObject(value);
if (layer.SerializedAttributes is null)
{
layer.SerializedAttributes = new Dictionary<string, object>();
}
layer.SerializedAttributes[name as string] = value;
} }
else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success)
{ {


+ 15
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Saving.SavedModel
{
public class RevivedInputLayer: Layer
{
private RevivedInputLayer(): base(null)
{
throw new NotImplementedException();
}
}
}

+ 27
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs View File

@@ -55,6 +55,21 @@ namespace Tensorflow.Keras.Saving.SavedModel


private RevivedConfig _config = null; private RevivedConfig _config = null;


public object keras_api
{
get
{
if (SerializedAttributes.TryGetValue(SavedModel.Constants.KERAS_ATTR, out var value))
{
return value;
}
else
{
return null;
}
}
}

public RevivedLayer(LayerArgs args): base(args) public RevivedLayer(LayerArgs args): base(args)
{ {


@@ -69,5 +84,17 @@ namespace Tensorflow.Keras.Saving.SavedModel
{ {
return _config; return _config;
} }

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
{
return base.Call(inputs, state, training);
}
else
{
return (func as Function).Apply(inputs);
}
}
} }
} }

+ 6
- 10
src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs View File

@@ -19,8 +19,8 @@ namespace Tensorflow.Keras.Saving.SavedModel
protected IDictionary<string, Trackable?> _object_dict; protected IDictionary<string, Trackable?> _object_dict;
protected IDictionary<string, Trackable?> _function_dict; protected IDictionary<string, Trackable?> _function_dict;
protected AutoTrackable _keras_trackable; protected AutoTrackable _keras_trackable;
protected HashSet<string> _all_functions;
protected HashSet<string> _all_checkpointable_objects;
internal HashSet<string> _all_functions;
internal HashSet<string> _all_checkpointable_objects;


private SerializedAttributes() private SerializedAttributes()
{ {
@@ -197,19 +197,15 @@ namespace Tensorflow.Keras.Saving.SavedModel
public class CommonEndPoints: SerializedAttributes public class CommonEndPoints: SerializedAttributes
{ {
public CommonEndPoints(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : public CommonEndPoints(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) :
//base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }),
// functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }))
base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables"}),
functions.Concat(new string[] { }))
base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }),
functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }))
{ {


} }


public CommonEndPoints() : public CommonEndPoints() :
//base(new string[] { "variables", "trainable_variables", "regularization_losses" },
// new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })
base(new string[] { "variables", "trainable_variables"},
new string[] {})
base(new string[] { "variables", "trainable_variables", "regularization_losses" },
new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })
{ {


} }


Loading…
Cancel
Save