| @@ -1,7 +1,7 @@ | |||||
| | | ||||
| Microsoft Visual Studio Solution File, Format Version 12.00 | Microsoft Visual Studio Solution File, Format Version 12.00 | ||||
| # Visual Studio Version 16 | |||||
| VisualStudioVersion = 16.0.31624.102 | |||||
| # Visual Studio Version 17 | |||||
| VisualStudioVersion = 17.4.33213.308 | |||||
| MinimumVisualStudioVersion = 10.0.40219.1 | MinimumVisualStudioVersion = 10.0.40219.1 | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" | ||||
| EndProject | EndProject | ||||
| @@ -23,6 +23,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", | |||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" | ||||
| EndProject | EndProject | ||||
| Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Common", "Tensorflow.Common\Tensorflow.Common.csproj", "{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}" | |||||
| EndProject | |||||
| Global | Global | ||||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
| Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
| @@ -153,6 +155,18 @@ Global | |||||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 | {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 | ||||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU | {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU | ||||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU | {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU | ||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.ActiveCfg = Debug|Any CPU | |||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.Build.0 = Debug|Any CPU | |||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.Build.0 = Release|Any CPU | |||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.ActiveCfg = Release|Any CPU | |||||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.Build.0 = Release|Any CPU | |||||
| EndGlobalSection | EndGlobalSection | ||||
| GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
| HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
| @@ -0,0 +1,13 @@ | |||||
| using OneOf; | |||||
| using System; | |||||
| namespace Tensorflow.Common.Extensions | |||||
| { | |||||
| public static class OneofExtension | |||||
| { | |||||
| public static bool IsTypeOrDeriveFrom<T>(this IOneOf src) | |||||
| { | |||||
| return src.Value is T; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,11 @@ | |||||
| <Project Sdk="Microsoft.NET.Sdk"> | |||||
| <PropertyGroup> | |||||
| <TargetFramework>netstandard2.0</TargetFramework> | |||||
| </PropertyGroup> | |||||
| <ItemGroup> | |||||
| <PackageReference Include="OneOf" Version="3.0.223" /> | |||||
| </ItemGroup> | |||||
| </Project> | |||||
| @@ -1,10 +1,12 @@ | |||||
| using System; | |||||
| using OneOf; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| using Tensorflow.Common.Extensions; | |||||
| using pbc = global::Google.Protobuf.Collections; | using pbc = global::Google.Protobuf.Collections; | ||||
| namespace Tensorflow.Checkpoint | namespace Tensorflow.Checkpoint | ||||
| @@ -28,7 +30,7 @@ namespace Tensorflow.Checkpoint | |||||
| ); | ); | ||||
| public static class SaveUtil | public static class SaveUtil | ||||
| { | { | ||||
| public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
| public static (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
| serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | ||||
| { | { | ||||
| var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | ||||
| @@ -117,16 +119,16 @@ namespace Tensorflow.Checkpoint | |||||
| /// <param name="call_with_mapped_captures"></param> | /// <param name="call_with_mapped_captures"></param> | ||||
| /// <param name="cache"></param> | /// <param name="cache"></param> | ||||
| /// <param name="object_graph_proto"></param> | /// <param name="object_graph_proto"></param> | ||||
| private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||||
| private static IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||||
| bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | ||||
| { | { | ||||
| Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||||
| Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||||
| foreach(var td in tensor_trackables) | foreach(var td in tensor_trackables) | ||||
| { | { | ||||
| // TODO: deal with cache. | // TODO: deal with cache. | ||||
| var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | ||||
| Trackable trackable = null; | Trackable trackable = null; | ||||
| IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||||
| IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||||
| if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | ||||
| { | { | ||||
| (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | ||||
| @@ -148,12 +150,12 @@ namespace Tensorflow.Checkpoint | |||||
| return serialized_tensors; | return serialized_tensors; | ||||
| } | } | ||||
| private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||||
| private static IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||||
| { | { | ||||
| var trackable = trackable_data.object_to_save; | var trackable = trackable_data.object_to_save; | ||||
| // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | ||||
| IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||||
| IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||||
| if (call_with_mapped_captures) | if (call_with_mapped_captures) | ||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| @@ -164,7 +166,7 @@ namespace Tensorflow.Checkpoint | |||||
| } | } | ||||
| // TODO: deal with the type `SaveSpce` (currently it will never be it). | // TODO: deal with the type `SaveSpce` (currently it will never be it). | ||||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||||
| foreach(var pair in ret_tensor_dict) | foreach(var pair in ret_tensor_dict) | ||||
| { | { | ||||
| var local_name = TrackableUtils.escape_local_name(pair.Key); | var local_name = TrackableUtils.escape_local_name(pair.Key); | ||||
| @@ -200,7 +202,7 @@ namespace Tensorflow.Checkpoint | |||||
| /// <param name="call_with_mapped_captures"></param> | /// <param name="call_with_mapped_captures"></param> | ||||
| /// <param name="object_graph_proto"></param> | /// <param name="object_graph_proto"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||||
| private static (Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||||
| bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | ||||
| { | { | ||||
| Dictionary<Trackable, string> object_names = new(); | Dictionary<Trackable, string> object_names = new(); | ||||
| @@ -8,6 +8,7 @@ using Tensorflow.Training; | |||||
| using pbc = global::Google.Protobuf.Collections; | using pbc = global::Google.Protobuf.Collections; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using OneOf; | |||||
| namespace Tensorflow.Checkpoint; | namespace Tensorflow.Checkpoint; | ||||
| @@ -179,13 +180,13 @@ public static class SaveUtilV1 | |||||
| // TODO: tensorflow python has a process with callable `saveable_factory`. | // TODO: tensorflow python has a process with callable `saveable_factory`. | ||||
| List<MySaveableObject> saveables = new(); | List<MySaveableObject> saveables = new(); | ||||
| if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | |||||
| if (maybe_saveable.TryPickT1(out var s, out var variable)) | |||||
| { | { | ||||
| saveables.Add(s); | saveables.Add(s); | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key)); | |||||
| saveables.AddRange(saveable_object_util.saveable_objects_for_op(variable as Trackable, key)); | |||||
| } | } | ||||
| foreach (var saveable in saveables) | foreach (var saveable in saveables) | ||||
| @@ -217,7 +218,7 @@ public static class SaveUtilV1 | |||||
| public record class CheckpointFactoryData | public record class CheckpointFactoryData | ||||
| ( | ( | ||||
| Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||||
| Func<string, OneOf<BaseResourceVariable, MySaveableObject>> factory, | |||||
| string name, | string name, | ||||
| string checkpoint_key | string checkpoint_key | ||||
| ); | ); | ||||
| @@ -12,6 +12,7 @@ using static Tensorflow.Binding; | |||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| using OneOf; | |||||
| namespace Tensorflow.Checkpoint; | namespace Tensorflow.Checkpoint; | ||||
| @@ -49,7 +50,7 @@ public class TrackableSaver | |||||
| } | } | ||||
| private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
| private (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
| gather_serialized_tensors(Tensor? object_graph_tensor = null) | gather_serialized_tensors(Tensor? object_graph_tensor = null) | ||||
| { | { | ||||
| var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | ||||
| @@ -68,7 +69,7 @@ public class TrackableSaver | |||||
| Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | ||||
| if (!serialized_tensors.ContainsKey(Trackable.None)) | if (!serialized_tensors.ContainsKey(Trackable.None)) | ||||
| { | { | ||||
| serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>(); | |||||
| serialized_tensors[Trackable.None] = new Dictionary<string, OneOf.OneOf<Tensor, IDictionary<string, Tensor>>>(); | |||||
| } | } | ||||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; | serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; | ||||
| return (serialized_tensors, feed_additions, registered_savers, graph_proto); | return (serialized_tensors, feed_additions, registered_savers, graph_proto); | ||||
| @@ -400,7 +401,7 @@ public class CheckpointRestoreCoordinator | |||||
| // skip the callback. | // skip the callback. | ||||
| } | } | ||||
| public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||||
| public List<Operation> restore_saveables(Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||||
| { | { | ||||
| List<Operation> restore_ops = new(); | List<Operation> restore_ops = new(); | ||||
| foreach(var position in positions) | foreach(var position in positions) | ||||
| @@ -412,7 +413,7 @@ public class CheckpointRestoreCoordinator | |||||
| Dictionary<string, BaseResourceVariable> variable_dict = new(); | Dictionary<string, BaseResourceVariable> variable_dict = new(); | ||||
| foreach(var item in tensor_saveables) | foreach(var item in tensor_saveables) | ||||
| { | { | ||||
| if(item.Value.TryGet<BaseResourceVariable>(out var variable)) | |||||
| if(item.Value.TryPickT0(out var variable, out var _)) | |||||
| { | { | ||||
| variable_dict[item.Key] = variable; | variable_dict[item.Key] = variable; | ||||
| } | } | ||||
| @@ -15,106 +15,14 @@ using Tensorflow.Graphs; | |||||
| using System.Xml.Linq; | using System.Xml.Linq; | ||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using RestoreFunc = System.Func<object, object>; | using RestoreFunc = System.Func<object, object>; | ||||
| using OneOf; | |||||
| namespace Tensorflow.Checkpoint | namespace Tensorflow.Checkpoint | ||||
| { | { | ||||
| public class Maybe<TA, TB> | |||||
| { | |||||
| private TA? _valueA = default(TA); | |||||
| private TB? _valueB = default(TB); | |||||
| private Type _type; | |||||
| private bool _assignedTA; | |||||
| public Maybe(TA value) | |||||
| { | |||||
| _valueA = value; | |||||
| _type= typeof(TA); | |||||
| _assignedTA = true; | |||||
| } | |||||
| public Maybe(TB value) | |||||
| { | |||||
| _valueB = value; | |||||
| _type = typeof(TB); | |||||
| _assignedTA = false; | |||||
| } | |||||
| public Type DataType => _type; | |||||
| /// <summary> | |||||
| /// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned. | |||||
| /// It returns | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="res"></param> | |||||
| /// <returns></returns> | |||||
| public bool TryGet<T>(out T? res) | |||||
| { | |||||
| if(_valueA is T && _valueB is not T) | |||||
| { | |||||
| res = (T)(object)_valueA; | |||||
| return _assignedTA; | |||||
| } | |||||
| else if(_valueA is not T && _valueB is T) | |||||
| { | |||||
| res = (T)(object)_valueB; | |||||
| return !_assignedTA; | |||||
| } | |||||
| res = default(T); | |||||
| return false; | |||||
| } | |||||
| public bool IsTypeOrDeriveFrom<T>() | |||||
| { | |||||
| if (_valueA is T && _valueB is not T) | |||||
| { | |||||
| return _assignedTA; | |||||
| } | |||||
| else if (_valueA is not T && _valueB is T) | |||||
| { | |||||
| return !_assignedTA; | |||||
| } | |||||
| else if (_valueA is T && _valueB is T) | |||||
| { | |||||
| return true; | |||||
| } | |||||
| else | |||||
| { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| public T GetValue<T>() | |||||
| { | |||||
| if (_valueA is T && _valueB is not T) | |||||
| { | |||||
| return (T)(object)_valueA; | |||||
| } | |||||
| else if (_valueA is not T && _valueB is T) | |||||
| { | |||||
| return (T)(object)_valueB; | |||||
| } | |||||
| else if (_valueA is T && _valueB is T) | |||||
| { | |||||
| throw new TypeError("The type is vague, this is always because TA and TB both derive from T."); | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}."); | |||||
| } | |||||
| } | |||||
| public static implicit operator Maybe<TA, TB>(TA a) | |||||
| { | |||||
| return new Maybe<TA, TB>(a); | |||||
| } | |||||
| public static implicit operator Maybe<TA, TB>(TB b) | |||||
| { | |||||
| return new Maybe<TA, TB>(b); | |||||
| } | |||||
| } | |||||
| internal class SingleDeviceSaver | internal class SingleDeviceSaver | ||||
| { | { | ||||
| private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict; | |||||
| public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict) | |||||
| private IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> _tensor_slice_dict; | |||||
| public SingleDeviceSaver(IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_slice_dict) | |||||
| { | { | ||||
| _tensor_slice_dict = tensor_slice_dict; | _tensor_slice_dict = tensor_slice_dict; | ||||
| } | } | ||||
| @@ -122,15 +30,15 @@ namespace Tensorflow.Checkpoint | |||||
| { | { | ||||
| _tensor_slice_dict = tensor_slice_dict.ToDictionary( | _tensor_slice_dict = tensor_slice_dict.ToDictionary( | ||||
| x => x.Key, x => x.Value.ToDictionary( | x => x.Key, x => x.Value.ToDictionary( | ||||
| y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||||
| as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||||
| y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT0(y.Value)) | |||||
| as IDictionary<string, OneOf<Tensor, SaveSpec>>); | |||||
| } | } | ||||
| public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict) | public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict) | ||||
| { | { | ||||
| _tensor_slice_dict = tensor_slice_dict.ToDictionary( | _tensor_slice_dict = tensor_slice_dict.ToDictionary( | ||||
| x => x.Key, x => x.Value.ToDictionary( | x => x.Key, x => x.Value.ToDictionary( | ||||
| y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||||
| as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||||
| y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT1(y.Value)) | |||||
| as IDictionary<string, OneOf<Tensor, SaveSpec>>); | |||||
| } | } | ||||
| public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) | public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) | ||||
| { | { | ||||
| @@ -149,7 +57,7 @@ namespace Tensorflow.Checkpoint | |||||
| { | { | ||||
| var slice_spec = slice.Key; | var slice_spec = slice.Key; | ||||
| var maybe_tensor = slice.Value; | var maybe_tensor = slice.Value; | ||||
| if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||||
| if(maybe_tensor.TryPickT1(out var spec, out var tensor)) | |||||
| { | { | ||||
| var tensor_value = spec.tensor; | var tensor_value = spec.tensor; | ||||
| if (tensor_value is not null) | if (tensor_value is not null) | ||||
| @@ -161,7 +69,6 @@ namespace Tensorflow.Checkpoint | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var tensor = maybe_tensor.GetValue<Tensor>(); | |||||
| tensor_names.Add(checkpoint_key); | tensor_names.Add(checkpoint_key); | ||||
| tensors.Add(tensor); | tensors.Add(tensor); | ||||
| slice_specs.Add(slice_spec); | slice_specs.Add(slice_spec); | ||||
| @@ -193,7 +100,7 @@ namespace Tensorflow.Checkpoint | |||||
| var slice_spec = slice.Key; | var slice_spec = slice.Key; | ||||
| var maybe_tensor = slice.Value; | var maybe_tensor = slice.Value; | ||||
| // TODO: deal with other types. Currently only `SaveSpec` is allowed. | // TODO: deal with other types. Currently only `SaveSpec` is allowed. | ||||
| if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||||
| if(maybe_tensor.TryPickT1(out var spec, out var tensor)) | |||||
| { | { | ||||
| tensor_dtypes.Add(spec.dtype); | tensor_dtypes.Add(spec.dtype); | ||||
| slice_specs.Add(spec.slice_spec); | slice_specs.Add(spec.slice_spec); | ||||
| @@ -201,7 +108,6 @@ namespace Tensorflow.Checkpoint | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var tensor = maybe_tensor.GetValue<Tensor>(); | |||||
| tensor_dtypes.Add(tensor.dtype); | tensor_dtypes.Add(tensor.dtype); | ||||
| slice_specs.Add(slice_spec); | slice_specs.Add(slice_spec); | ||||
| tensor_names.Add(checkpoint_key); | tensor_names.Add(checkpoint_key); | ||||
| @@ -254,7 +160,7 @@ namespace Tensorflow.Checkpoint | |||||
| /// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | /// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | ||||
| /// <param name="registered_savers"></param> | /// <param name="registered_savers"></param> | ||||
| /// <param name="call_with_mapped_capture"></param> | /// <param name="call_with_mapped_capture"></param> | ||||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||||
| IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | ||||
| { | { | ||||
| _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | ||||
| @@ -274,9 +180,9 @@ namespace Tensorflow.Checkpoint | |||||
| { | { | ||||
| restore_fn = new RestoreFunc(x => | restore_fn = new RestoreFunc(x => | ||||
| { | { | ||||
| if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) | |||||
| if(x is IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) | |||||
| { | { | ||||
| return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>); | |||||
| return obj._restore_from_tensors(x as IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>); | |||||
| } | } | ||||
| throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}."); | throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}."); | ||||
| }); | }); | ||||
| @@ -286,14 +192,14 @@ namespace Tensorflow.Checkpoint | |||||
| { | { | ||||
| var checkpoint_key = item.Key; | var checkpoint_key = item.Key; | ||||
| IDictionary<string, Tensor> spec_to_tensor; | IDictionary<string, Tensor> spec_to_tensor; | ||||
| if(item.Value.TryGet<Tensor>(out var t)) | |||||
| if(item.Value.TryPickT0(out var t, out var dic)) | |||||
| { | { | ||||
| spec_to_tensor = new Dictionary<string, Tensor>(); | spec_to_tensor = new Dictionary<string, Tensor>(); | ||||
| spec_to_tensor[""] = t; | spec_to_tensor[""] = t; | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>(); | |||||
| spec_to_tensor = dic; | |||||
| } | } | ||||
| foreach(var spec in spec_to_tensor) | foreach(var spec in spec_to_tensor) | ||||
| @@ -399,7 +305,7 @@ namespace Tensorflow.Checkpoint | |||||
| IDictionary<string, Operation> restore_func() | IDictionary<string, Operation> restore_func() | ||||
| { | { | ||||
| Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||||
| Dictionary<RestoreFunc, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||||
| Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | ||||
| Dictionary<string, Operation> restore_ops = new(); | Dictionary<string, Operation> restore_ops = new(); | ||||
| @@ -419,29 +325,29 @@ namespace Tensorflow.Checkpoint | |||||
| var slice_spec = item.Key; | var slice_spec = item.Key; | ||||
| var tensor = item.Value; | var tensor = item.Value; | ||||
| var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | ||||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>()); | |||||
| if (!string.IsNullOrEmpty(slice_spec)) | if (!string.IsNullOrEmpty(slice_spec)) | ||||
| { | { | ||||
| if (!internal_dict.ContainsKey(checkpoint_key)) | if (!internal_dict.ContainsKey(checkpoint_key)) | ||||
| { | { | ||||
| Dictionary<string, Tensor> dict = new(); | Dictionary<string, Tensor> dict = new(); | ||||
| dict[slice_spec] = tensor; | dict[slice_spec] = tensor; | ||||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||||
| internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; | |||||
| } | } | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor); | |||||
| } | } | ||||
| restore_fn_input_count[restore_fn]--; | restore_fn_input_count[restore_fn]--; | ||||
| if (restore_fn_input_count[restore_fn] == 0) | if (restore_fn_input_count[restore_fn] == 0) | ||||
| { | { | ||||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||||
| foreach(var input in restore_fn_inputs[restore_fn]) | foreach(var input in restore_fn_inputs[restore_fn]) | ||||
| { | { | ||||
| restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | ||||
| @@ -519,7 +425,7 @@ namespace Tensorflow.Checkpoint | |||||
| public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | ||||
| { | { | ||||
| Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||||
| Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||||
| foreach (var saveable in saveables) | foreach (var saveable in saveables) | ||||
| { | { | ||||
| var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using OneOf; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Linq; | using System.Linq; | ||||
| @@ -61,13 +62,13 @@ public class CheckpointPosition | |||||
| } | } | ||||
| } | } | ||||
| public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||||
| public (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||||
| { | { | ||||
| // skip the registered_saver | // skip the registered_saver | ||||
| if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) | if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) | ||||
| { | { | ||||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||||
| return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(), | |||||
| new List<CheckpointPosition>(), null); | new List<CheckpointPosition>(), null); | ||||
| } | } | ||||
| @@ -75,7 +76,7 @@ public class CheckpointPosition | |||||
| List<Operation> existing_restore_ops; | List<Operation> existing_restore_ops; | ||||
| List<CheckpointPosition> positions = new(); | List<CheckpointPosition> positions = new(); | ||||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables; | |||||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables; | |||||
| if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | ||||
| { | { | ||||
| (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); | (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); | ||||
| @@ -109,8 +110,8 @@ public class CheckpointPosition | |||||
| /// Creates a saveable using the _serialize_to_tensor method. | /// Creates a saveable using the _serialize_to_tensor method. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="saveable_factories"></param> | /// <param name="saveable_factories"></param> | ||||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||||
| IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||||
| private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||||
| IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||||
| { | { | ||||
| string suffix = SaveableCompat.get_saveable_name(this.Trackable); | string suffix = SaveableCompat.get_saveable_name(this.Trackable); | ||||
| suffix = suffix ?? ""; | suffix = suffix ?? ""; | ||||
| @@ -124,23 +125,23 @@ public class CheckpointPosition | |||||
| var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); | var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); | ||||
| // skip the cache. | // skip the cache. | ||||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new(); | |||||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> dict = new(); | |||||
| dict[saveable_name] = saveable; | dict[saveable_name] = saveable; | ||||
| return (new List<Operation>(), dict); | return (new List<Operation>(), dict); | ||||
| } | } | ||||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||||
| IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||||
| private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||||
| IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||||
| { | { | ||||
| // TODO(Rinne): implement it. | // TODO(Rinne): implement it. | ||||
| if(ObjectProto.Attributes is null) | if(ObjectProto.Attributes is null) | ||||
| { | { | ||||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>()); | |||||
| return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>()); | |||||
| } | } | ||||
| List<Operation> existing_restore_ops = new(); | List<Operation> existing_restore_ops = new(); | ||||
| HashSet<string> created_compat_names = new(); | HashSet<string> created_compat_names = new(); | ||||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||||
| foreach (var serialized_tensor in ObjectProto.Attributes) | foreach (var serialized_tensor in ObjectProto.Attributes) | ||||
| { | { | ||||
| Operation existing_op; | Operation existing_op; | ||||
| @@ -172,12 +173,12 @@ public class CheckpointPosition | |||||
| _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); | _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); | ||||
| continue; | continue; | ||||
| } | } | ||||
| named_saveables[serialized_tensor.CheckpointKey] = saveable; | |||||
| named_saveables[serialized_tensor.CheckpointKey] = saveable.Value; | |||||
| } | } | ||||
| return (existing_restore_ops, named_saveables); | return (existing_restore_ops, named_saveables); | ||||
| } | } | ||||
| private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||||
| private OneOf<BaseResourceVariable, MySaveableObject>? _get_saveable_from_factory(IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||||
| TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) | TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) | ||||
| { | { | ||||
| var expected_factory_name = serialized_tensor.Name; | var expected_factory_name = serialized_tensor.Name; | ||||
| @@ -221,7 +222,7 @@ public class CheckpointPosition | |||||
| Queue<(CheckpointPosition, Trackable)> visit_queue = new(); | Queue<(CheckpointPosition, Trackable)> visit_queue = new(); | ||||
| visit_queue.Enqueue((this, this.Trackable)); | visit_queue.Enqueue((this, this.Trackable)); | ||||
| List<Operation> restore_ops = new(); | List<Operation> restore_ops = new(); | ||||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||||
| List<CheckpointPosition> positions = new(); | List<CheckpointPosition> positions = new(); | ||||
| CheckpointPosition current_position = null; | CheckpointPosition current_position = null; | ||||
| @@ -306,7 +307,7 @@ public class CheckpointPosition | |||||
| } | } | ||||
| } | } | ||||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||||
| private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||||
| { | { | ||||
| var trackable = this.Trackable; | var trackable = this.Trackable; | ||||
| trackable._maybe_initialize_trackable(); | trackable._maybe_initialize_trackable(); | ||||
| @@ -318,7 +319,7 @@ public class CheckpointPosition | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||||
| return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(), | |||||
| new List<CheckpointPosition>(), null); | new List<CheckpointPosition>(), null); | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,13 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public class TangentInfo | |||||
| { | |||||
| // TODO(Rinne): implement it. | |||||
| public object Indices { get; set; } | |||||
| public object Tangents { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -1,6 +1,8 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Eager; | |||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| @@ -17,11 +19,13 @@ namespace Tensorflow.Functions | |||||
| internal FuncGraph func_graph; | internal FuncGraph func_graph; | ||||
| protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; | protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; | ||||
| protected Dictionary<string, string> _attrs; | protected Dictionary<string, string> _attrs; | ||||
| protected FunctionSpec _function_spec; | |||||
| protected FunctionSpec _pre_initialized_function_spec = null; | |||||
| internal ForwardBackwardCall forward_backward; | internal ForwardBackwardCall forward_backward; | ||||
| public Tensor[] Inputs => func_graph.Inputs; | public Tensor[] Inputs => func_graph.Inputs; | ||||
| public Tensor[] CapturedInputs => func_graph.external_captures; | public Tensor[] CapturedInputs => func_graph.external_captures; | ||||
| public string Name => _delayed_rewrite_functions.forward().Name; | |||||
| public string Name => _delayed_rewrite_functions.Forward().Name; | |||||
| public Tensor[] Outputs; | public Tensor[] Outputs; | ||||
| public Type ReturnType; | public Type ReturnType; | ||||
| @@ -175,7 +179,13 @@ namespace Tensorflow.Functions | |||||
| var (forward_function, args_with_tangents) = forward_backward.Forward(); | var (forward_function, args_with_tangents) = forward_backward.Forward(); | ||||
| Tensors flat_outputs = null; | Tensors flat_outputs = null; | ||||
| if (executing_eagerly) | if (executing_eagerly) | ||||
| { | |||||
| flat_outputs = forward_function.Call(args_with_tangents); | |||||
| } | |||||
| else | |||||
| { | |||||
| flat_outputs = forward_function.Call(args_with_tangents); | flat_outputs = forward_function.Call(args_with_tangents); | ||||
| } | |||||
| forward_backward.Record(flat_outputs); | forward_backward.Record(flat_outputs); | ||||
| return flat_outputs; | return flat_outputs; | ||||
| } | } | ||||
| @@ -186,7 +196,7 @@ namespace Tensorflow.Functions | |||||
| { | { | ||||
| g = ops.get_default_graph(); | g = ops.get_default_graph(); | ||||
| } | } | ||||
| _delayed_rewrite_functions.forward().AddToGraph(g); | |||||
| _delayed_rewrite_functions.Forward().AddToGraph(g); | |||||
| } | } | ||||
| public void SetExternalCaptures(IEnumerable<Tensor> captures) | public void SetExternalCaptures(IEnumerable<Tensor> captures) | ||||
| @@ -196,8 +206,60 @@ namespace Tensorflow.Functions | |||||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ||||
| { | { | ||||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||||
| return new ForwardBackwardCall(functions, args, tape_watching: true); | |||||
| TangentInfo input_tangents; | |||||
| if (executing_eagerly) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| else | |||||
| { | |||||
| input_tangents = new TangentInfo(); | |||||
| } | |||||
| if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER) | |||||
| { | |||||
| if(input_tangents.Indices is not null || executing_eagerly) | |||||
| { | |||||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||||
| return new ForwardBackwardCall(functions, args, tape_watching: true); | |||||
| } | |||||
| else | |||||
| { | |||||
| return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: true); | |||||
| } | |||||
| } | |||||
| else if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| // TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall. | |||||
| return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); | |||||
| } | |||||
| internal void _set_function_spec(FunctionSpec spec) | |||||
| { | |||||
| _function_spec = null; | |||||
| _pre_initialized_function_spec = spec; | |||||
| _initialize_function_spec(); | |||||
| } | |||||
| internal void _initialize_function_spec() | |||||
| { | |||||
| if(_pre_initialized_function_spec is null) | |||||
| { | |||||
| return; | |||||
| } | |||||
| Debug.Assert(_function_spec is null, "already initialized"); | |||||
| var spec = _pre_initialized_function_spec; | |||||
| //var args = spec.Fullargspec.DictValue.Fields["args"]; | |||||
| // TODO(Rinne): self.structured_input_signature | |||||
| _function_spec = new FunctionSpec() | |||||
| { | |||||
| Fullargspec = spec.Fullargspec, | |||||
| IsMethod = spec.IsMethod, | |||||
| InputSignature = spec.InputSignature | |||||
| }; | |||||
| } | } | ||||
| public override string ToString() | public override string ToString() | ||||
| @@ -5,6 +5,8 @@ using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using Tensorflow.Operations; | |||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Functions | namespace Tensorflow.Functions | ||||
| @@ -14,7 +16,10 @@ namespace Tensorflow.Functions | |||||
| public int _num_outputs; | public int _num_outputs; | ||||
| FuncGraph _func_graph; | FuncGraph _func_graph; | ||||
| FunctionDef _definition; | FunctionDef _definition; | ||||
| Tensor[] _func_graph_outputs; | |||||
| public string Name => _func_graph.FuncName; | public string Name => _func_graph.FuncName; | ||||
| public DataType[] OutputTypes { get; protected set; } | |||||
| public Shape[] OutputShapes { get; protected set; } | |||||
| public FunctionDef Definition | public FunctionDef Definition | ||||
| { | { | ||||
| get | get | ||||
| @@ -36,27 +41,69 @@ namespace Tensorflow.Functions | |||||
| var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) | var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) | ||||
| .Select(x => x as Operation).ToArray(); | .Select(x => x as Operation).ToArray(); | ||||
| var output_names = new string[0]; | var output_names = new string[0]; | ||||
| OutputShapes = outputs.Select(x => x.shape).ToArray(); | |||||
| OutputTypes = outputs.Select(x => x.dtype.as_datatype_enum()).ToArray(); | |||||
| _func_graph = new FuncGraph(graph, name, attrs); | _func_graph = new FuncGraph(graph, name, attrs); | ||||
| _func_graph_outputs = new List<Tensor>(outputs).ToArray(); | |||||
| _func_graph.ToGraph(operations, inputs, outputs, output_names); | _func_graph.ToGraph(operations, inputs, outputs, output_names); | ||||
| } | } | ||||
| public Tensors Call(Tensors args) | public Tensors Call(Tensors args) | ||||
| { | { | ||||
| // TODO(Rinne): Add arg `CancellationManager`. | |||||
| // TODO(Rinne): Check the arg length. | |||||
| var function_call_options = tf.Context.FunctionCallOptions; | |||||
| string config; | |||||
| if (string.IsNullOrEmpty(function_call_options.config_proto_serialized())) | |||||
| { | |||||
| config = function_utils.get_disabled_rewriter_config(); | |||||
| } | |||||
| else | |||||
| { | |||||
| config = function_call_options.config_proto_serialized(); | |||||
| } | |||||
| // TODO(Rinne): executor_type | |||||
| var executing_eagerly = tf.Context.executing_eagerly(); | |||||
| var attrs = new object[] | var attrs = new object[] | ||||
| { | { | ||||
| "executor_type", "", | "executor_type", "", | ||||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | ||||
| }; | }; | ||||
| var results = tf.Runner.TFE_Execute(tf.Context, | |||||
| Tensor[] outputs; | |||||
| if (executing_eagerly) | |||||
| { | |||||
| outputs = tf.Runner.TFE_Execute(tf.Context, | |||||
| tf.Context.DeviceName, | tf.Context.DeviceName, | ||||
| _func_graph.FuncName, | _func_graph.FuncName, | ||||
| args, | args, | ||||
| attrs, | attrs, | ||||
| _num_outputs); | _num_outputs); | ||||
| return results; | |||||
| } | |||||
| else | |||||
| { | |||||
| tf.GradientTape().stop_recording(); | |||||
| outputs = functional_ops.partitioned_call(args, this, OutputTypes, | |||||
| executing_eagerly, config, ""); | |||||
| } | |||||
| foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs)) | |||||
| { | |||||
| handle_data_util.copy_handle_data(func_graph_output, outputs[i]); | |||||
| } | |||||
| if (executing_eagerly) | |||||
| { | |||||
| return outputs; | |||||
| } | |||||
| else | |||||
| { | |||||
| foreach(var (i, shape) in enumerate(OutputShapes)) | |||||
| { | |||||
| outputs[i].shape = shape; | |||||
| } | |||||
| return outputs; | |||||
| } | |||||
| } | } | ||||
| public void AddToGraph(Graph g = null) | public void AddToGraph(Graph g = null) | ||||
| @@ -9,16 +9,46 @@ namespace Tensorflow | |||||
| #pragma warning disable CS0169 // The field 'Function._handle' is never used | #pragma warning disable CS0169 // The field 'Function._handle' is never used | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| #pragma warning restore CS0169 // The field 'Function._handle' is never used | #pragma warning restore CS0169 // The field 'Function._handle' is never used | ||||
| protected Func<Tensors, Tensors> _function; | |||||
| protected ConcreteFunction _concrete_variable_creation_fn; | |||||
| protected bool _auto_graph; | |||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| public Function() | |||||
| public Function(Func<Tensors, Tensors> function, | |||||
| string name, bool auto_graph = true) | |||||
| { | |||||
| _function = function; | |||||
| Name = name; | |||||
| _auto_graph = auto_graph; | |||||
| } | |||||
| public virtual Tensors Apply(Tensors inputs) | |||||
| { | { | ||||
| if (_run_functions_eagerly()) | |||||
| { | |||||
| return _function(inputs); | |||||
| } | |||||
| var result = _call(inputs); | |||||
| return result; | |||||
| } | } | ||||
| public Function(string name) | |||||
| protected virtual Tensors _call(Tensors inputs) | |||||
| { | { | ||||
| Name = name; | |||||
| _initialize(); | |||||
| return _concrete_variable_creation_fn.CallFlat(inputs, | |||||
| _concrete_variable_creation_fn.CapturedInputs); | |||||
| } | |||||
| protected virtual bool _run_functions_eagerly() | |||||
| { | |||||
| return false; | |||||
| } | |||||
| private void _initialize() | |||||
| { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -15,11 +15,11 @@ namespace Tensorflow.Functions | |||||
| /// </summary> | /// </summary> | ||||
| public abstract class TapeGradientFunctions | public abstract class TapeGradientFunctions | ||||
| { | { | ||||
| string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||||
| string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||||
| string _FORWARD_PREFIX = "__forward_"; | |||||
| string _BACKWARD_PREFIX = "__backward_"; | |||||
| string _INFERENCE_PREFIX = "__inference_"; | |||||
| protected string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||||
| protected string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||||
| protected string _FORWARD_PREFIX = "__forward_"; | |||||
| protected string _BACKWARD_PREFIX = "__backward_"; | |||||
| protected string _INFERENCE_PREFIX = "__inference_"; | |||||
| protected FuncGraph _func_graph; | protected FuncGraph _func_graph; | ||||
| protected EagerDefinedFunction _forward; | protected EagerDefinedFunction _forward; | ||||
| @@ -35,8 +35,9 @@ namespace Tensorflow.Functions | |||||
| _func_graph = func_graph; | _func_graph = func_graph; | ||||
| } | } | ||||
| public EagerDefinedFunction Forward(Tensors inference_args) | |||||
| public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null) | |||||
| { | { | ||||
| // TODO(Rinne): add input_tangents arg. | |||||
| return ForwardAndBackwardFunctions(inference_args); | return ForwardAndBackwardFunctions(inference_args); | ||||
| } | } | ||||
| @@ -45,8 +46,9 @@ namespace Tensorflow.Functions | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="flat_outputs"></param> | /// <param name="flat_outputs"></param> | ||||
| /// <param name="inference_args"></param> | /// <param name="inference_args"></param> | ||||
| public void Record(Tensors flat_outputs, Tensors inference_args) | |||||
| public virtual void Record(Tensors flat_outputs, Tensors inference_args) | |||||
| { | { | ||||
| // TODO(Rinne): add arg `input_tagents`. | |||||
| var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | ||||
| tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, | tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, | ||||
| getBackwardFunction: backward_function); | getBackwardFunction: backward_function); | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Variables; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Functions | namespace Tensorflow.Functions | ||||
| @@ -5,16 +5,13 @@ using Tensorflow.Graphs; | |||||
| namespace Tensorflow.Functions | namespace Tensorflow.Functions | ||||
| { | { | ||||
| public class DelayedRewriteGradientFunctions | |||||
| public class DelayedRewriteGradientFunctions: TapeGradientFunctions | |||||
| { | { | ||||
| static readonly string _INFERENCE_PREFIX = "__inference_"; | |||||
| static readonly string _BACKWARD_PREFIX = "__backward_"; | |||||
| static readonly string _FORWARD_PREFIX = "__forward_"; | |||||
| FuncGraph _func_graph; | |||||
| EagerDefinedFunction _inference_function; | EagerDefinedFunction _inference_function; | ||||
| Dictionary<string, string> _attrs; | Dictionary<string, string> _attrs; | ||||
| int _num_inference_outputs; | int _num_inference_outputs; | ||||
| public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, string> attrs) | public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, string> attrs) | ||||
| :base(func_graph, false) | |||||
| { | { | ||||
| _func_graph= func_graph; | _func_graph= func_graph; | ||||
| _inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name), | _inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name), | ||||
| @@ -23,7 +20,7 @@ namespace Tensorflow.Functions | |||||
| _num_inference_outputs = _func_graph.Outputs.Length; | _num_inference_outputs = _func_graph.Outputs.Length; | ||||
| } | } | ||||
| public EagerDefinedFunction forward(Tensors inference_args = null, Tensors input_tangents = null) | |||||
| public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null) | |||||
| { | { | ||||
| if(input_tangents is not null) | if(input_tangents is not null) | ||||
| { | { | ||||
| @@ -33,7 +30,23 @@ namespace Tensorflow.Functions | |||||
| return _inference_function; | return _inference_function; | ||||
| } | } | ||||
| private static string _inference_name(string name) | |||||
| public override void Record(Tensors flat_outputs, Tensors inference_args) | |||||
| { | |||||
| // TODO(Rinne): implement it. | |||||
| throw new NotImplementedException(); | |||||
| base.Record(flat_outputs, inference_args); | |||||
| } | |||||
| //private (BackwardFunction, Tensors) _backward(Tensors outputs) | |||||
| //{ | |||||
| // Tensor[] backward_function(Tensor[] grads, long[] unneeded_gradients) | |||||
| // { | |||||
| // var call_op = outputs[0].op; | |||||
| // } | |||||
| //} | |||||
| private string _inference_name(string name) | |||||
| { | { | ||||
| return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; | return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; | ||||
| } | } | ||||
| @@ -25,6 +25,11 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class gradients_util | public class gradients_util | ||||
| { | { | ||||
| // Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are | |||||
| // unfortunately too slow to use here. | |||||
| public static int POSSIBLE_GRADIENT_TYPES_NONE = 0; | |||||
| public static int POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1; | |||||
| public static int POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2; | |||||
| public static Tensor[] _GradientsHelper(Tensor[] ys, | public static Tensor[] _GradientsHelper(Tensor[] ys, | ||||
| Tensor[] xs, | Tensor[] xs, | ||||
| Tensor[] grad_ys = null, | Tensor[] grad_ys = null, | ||||
| @@ -129,6 +129,7 @@ namespace Tensorflow | |||||
| protected Graph outer_graph; | protected Graph outer_graph; | ||||
| public Graph OuterGraph => outer_graph; | public Graph OuterGraph => outer_graph; | ||||
| public Dictionary<string, EagerDefinedFunction> Functions => _functions; | public Dictionary<string, EagerDefinedFunction> Functions => _functions; | ||||
| public SafeGraphHandle c_graph => _handle; | |||||
| public Graph() | public Graph() | ||||
| { | { | ||||
| @@ -208,5 +208,9 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); | public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data); | |||||
| } | } | ||||
| } | } | ||||
| @@ -14,10 +14,14 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Google.Protobuf; | |||||
| using Google.Protobuf.WellKnownTypes; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.Operations; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -25,6 +29,72 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class functional_ops | public class functional_ops | ||||
| { | { | ||||
| public static Tensor[] partitioned_call(Tensors args, EagerDefinedFunction f, DataType[] tout, | |||||
| bool executing_eagerly, string config, string executor_type) | |||||
| { | |||||
| if (tout is null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| if (config is null) | |||||
| { | |||||
| config = function_utils.get_disabled_rewriter_config(); | |||||
| } | |||||
| if (executor_type is null) | |||||
| { | |||||
| executor_type = ""; | |||||
| } | |||||
| if (executing_eagerly) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| var converted_args = args.Select(x => ops.convert_to_tensor(x)).ToArray(); | |||||
| AttrValue tin_attr = new() | |||||
| { | |||||
| List = new AttrValue.Types.ListValue() | |||||
| }; | |||||
| tin_attr.List.Type.AddRange(args.Select(x => x.dtype.as_datatype_enum())); | |||||
| AttrValue tout_attr = new() | |||||
| { | |||||
| List = new AttrValue.Types.ListValue() | |||||
| }; | |||||
| tout_attr.List.Type.AddRange(tout); | |||||
| AttrValue func_attr = new() | |||||
| { | |||||
| Func = new NameAttrList() | |||||
| }; | |||||
| func_attr.Func.Name = f.Name; | |||||
| AttrValue executor_type_attr = new AttrValue() | |||||
| { | |||||
| S = tf.compat.as_bytes(executor_type) | |||||
| }; | |||||
| AttrValue config_proto = new AttrValue() | |||||
| { | |||||
| S = ByteString.CopyFromUtf8(executor_type) | |||||
| }; | |||||
| var graph = ops.get_default_graph(); | |||||
| f.AddToGraph(graph); | |||||
| // TODO(Rinne): complete it with `f.stateful` | |||||
| var op_name = "PartitionedCall"; | |||||
| string xla_compile_attr = "_XlaMustCompile"; | |||||
| Dictionary<string, AttrValue> op_attrs = new(); | |||||
| op_attrs["Tin"] = tin_attr; | |||||
| op_attrs["Tout"] = tout_attr; | |||||
| op_attrs["f"] = func_attr; | |||||
| op_attrs["config_proto"] = config_proto; | |||||
| op_attrs["executor_type"] = executor_type_attr; | |||||
| // TODO(Rinne): deal with `f.definition`. | |||||
| var op = graph.create_op(op_name, args, tout.Select(x => x.as_tf_dtype()).ToArray(), | |||||
| name: op_name, attrs: op_attrs); | |||||
| var outputs = op.outputs; | |||||
| // TODO(Rinne): deal with `f.graph`. | |||||
| return outputs; | |||||
| } | |||||
| public static Tensor scan( | public static Tensor scan( | ||||
| Func<Tensor, Tensor, Tensor> fn, | Func<Tensor, Tensor, Tensor> fn, | ||||
| Tensor elems, | Tensor elems, | ||||
| @@ -0,0 +1,83 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using System.Xml.Linq; | |||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Eager; | |||||
| using Tensorflow.Functions; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Operations | |||||
| { | |||||
| public class gen_functional_ops | |||||
| { | |||||
| public static Tensor[] partitioned_call(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, | |||||
| string config = "", string config_proto = "", string executor_type = "", string name = null) | |||||
| { | |||||
| var ctx = tf.Context; | |||||
| if (ctx.executing_eagerly()) | |||||
| { | |||||
| try | |||||
| { | |||||
| return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("PartitionedCall", name, | |||||
| args, tout, f, config, config_proto, executor_type)); | |||||
| } | |||||
| catch (Exception) | |||||
| { | |||||
| } | |||||
| } | |||||
| if (config is null) | |||||
| { | |||||
| config = ""; | |||||
| } | |||||
| if (config_proto is null) | |||||
| { | |||||
| config_proto = ""; | |||||
| } | |||||
| if (executor_type is null) | |||||
| { | |||||
| executor_type = ""; | |||||
| } | |||||
| Dictionary<string, object> kwargs = new(); | |||||
| kwargs["args"] = args; | |||||
| kwargs["Tout"] = tout; | |||||
| kwargs["f"] = f; | |||||
| kwargs["config"] = config; | |||||
| kwargs["config_proto"] = config_proto; | |||||
| kwargs["executor_type"] = executor_type; | |||||
| var output = tf.OpDefLib._apply_op_helper("PartitionedCall", | |||||
| name, kwargs); | |||||
| var result = output.outputs; | |||||
| if (execute.must_record_gradient()) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| return result; | |||||
| } | |||||
| public static Tensor[] partitioned_call_eager_fallback(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, | |||||
| string config, string config_proto, string executor_type, string name, Context ctx) | |||||
| { | |||||
| // TODO(Rinne): implement it. | |||||
| throw new NotImplementedException(); | |||||
| if(config is null) | |||||
| { | |||||
| config = ""; | |||||
| } | |||||
| if(config_proto is null) | |||||
| { | |||||
| config_proto = ""; | |||||
| } | |||||
| if(executor_type is null) | |||||
| { | |||||
| executor_type = ""; | |||||
| } | |||||
| object[] attrs = new object[] | |||||
| { | |||||
| }; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,7 +1,9 @@ | |||||
| using System; | |||||
| using Google.Protobuf; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| { | { | ||||
| @@ -11,18 +13,31 @@ namespace Tensorflow.Operations | |||||
| { | { | ||||
| if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) | if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) | ||||
| { | { | ||||
| SafeTensorHandle handle_data; | |||||
| HandleData handle_data; | |||||
| if(source_t is EagerTensor) | if(source_t is EagerTensor) | ||||
| { | { | ||||
| handle_data = source_t.Handle; | |||||
| handle_data = source_t.HandleData; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| handle_data = ops.get_resource_handle_data(source_t); | handle_data = ops.get_resource_handle_data(source_t); | ||||
| } | } | ||||
| throw new NotImplementedException(); | |||||
| //if(handle_data is not null && handle_data.) | |||||
| if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null | |||||
| && handle_data.ShapeAndType.Count > 0) | |||||
| { | |||||
| set_handle_data(target_t, handle_data); | |||||
| } | |||||
| } | |||||
| } | |||||
| public static void set_handle_data(Tensor target_t, HandleData handle_data) | |||||
| { | |||||
| if(target_t is EagerTensor) | |||||
| { | |||||
| target_t.HandleData = handle_data; | |||||
| return; | |||||
| } | } | ||||
| c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -39,7 +39,7 @@ namespace Tensorflow | |||||
| public static bool is_resource_variable(IVariableV1 var) | public static bool is_resource_variable(IVariableV1 var) | ||||
| { | { | ||||
| return var is ResourceVariable; | |||||
| return var is BaseResourceVariable; | |||||
| } | } | ||||
| public static bool is_resource_variable(Trackable var) | public static bool is_resource_variable(Trackable var) | ||||
| @@ -231,5 +231,21 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| public static void _maybe_set_handle_data(TF_DataType dtype, Tensor handle, Tensor tensor) | |||||
| { | |||||
| if(dtype == dtypes.variant) | |||||
| { | |||||
| var handle_data = get_eager_safe_handle_data(handle); | |||||
| if(handle_data.IsSet && handle_data.ShapeAndType.Count > 1) | |||||
| { | |||||
| tensor.HandleData = new HandleData() | |||||
| { | |||||
| IsSet = true | |||||
| }; | |||||
| tensor.HandleData.ShapeAndType.AddRange(handle_data.ShapeAndType.Skip(1)); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -479,7 +479,7 @@ namespace Tensorflow { | |||||
| /// </summary> | /// </summary> | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
| public pbc::RepeatedField<global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType> ShapeAndType { | public pbc::RepeatedField<global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType> ShapeAndType { | ||||
| get { return shapeAndType_; } | |||||
| get { return shapeAndType_; } | |||||
| } | } | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
| @@ -277,15 +277,15 @@ namespace Tensorflow { | |||||
| get { return Descriptor; } | get { return Descriptor; } | ||||
| } | } | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public SavedObject() { | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public SavedObject() { | |||||
| OnConstruction(); | OnConstruction(); | ||||
| } | } | ||||
| partial void OnConstruction(); | partial void OnConstruction(); | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public SavedObject(SavedObject other) : this() { | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public SavedObject(SavedObject other) : this() { | |||||
| children_ = other.children_.Clone(); | children_ = other.children_.Clone(); | ||||
| dependencies_ = other.dependencies_.Clone(); | dependencies_ = other.dependencies_.Clone(); | ||||
| slotVariables_ = other.slotVariables_.Clone(); | slotVariables_ = other.slotVariables_.Clone(); | ||||
| @@ -329,7 +329,9 @@ namespace Tensorflow { | |||||
| public const int ChildrenFieldNumber = 1; | public const int ChildrenFieldNumber = 1; | ||||
| private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | ||||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | ||||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||||
| private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_dependencies_codec | |||||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | ||||
| /// <summary> | /// <summary> | ||||
| /// Objects which this object depends on: named edges in the dependency | /// Objects which this object depends on: named edges in the dependency | ||||
| @@ -501,7 +503,8 @@ namespace Tensorflow { | |||||
| return true; | return true; | ||||
| } | } | ||||
| if(!children_.Equals(other.children_)) return false; | if(!children_.Equals(other.children_)) return false; | ||||
| if(!slotVariables_.Equals(other.slotVariables_)) return false; | |||||
| if (!dependencies_.Equals(other.dependencies_)) return false; | |||||
| if (!slotVariables_.Equals(other.slotVariables_)) return false; | |||||
| if (!object.Equals(UserObject, other.UserObject)) return false; | if (!object.Equals(UserObject, other.UserObject)) return false; | ||||
| if (!object.Equals(Asset, other.Asset)) return false; | if (!object.Equals(Asset, other.Asset)) return false; | ||||
| if (!object.Equals(Function, other.Function)) return false; | if (!object.Equals(Function, other.Function)) return false; | ||||
| @@ -519,6 +522,7 @@ namespace Tensorflow { | |||||
| public override int GetHashCode() { | public override int GetHashCode() { | ||||
| int hash = 1; | int hash = 1; | ||||
| hash ^= children_.GetHashCode(); | hash ^= children_.GetHashCode(); | ||||
| hash ^= dependencies_.GetHashCode(); | |||||
| hash ^= slotVariables_.GetHashCode(); | hash ^= slotVariables_.GetHashCode(); | ||||
| if (kindCase_ == KindOneofCase.UserObject) hash ^= UserObject.GetHashCode(); | if (kindCase_ == KindOneofCase.UserObject) hash ^= UserObject.GetHashCode(); | ||||
| if (kindCase_ == KindOneofCase.Asset) hash ^= Asset.GetHashCode(); | if (kindCase_ == KindOneofCase.Asset) hash ^= Asset.GetHashCode(); | ||||
| @@ -544,6 +548,7 @@ namespace Tensorflow { | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
| public void WriteTo(pb::CodedOutputStream output) { | public void WriteTo(pb::CodedOutputStream output) { | ||||
| children_.WriteTo(output, _repeated_children_codec); | children_.WriteTo(output, _repeated_children_codec); | ||||
| children_.WriteTo(output, _repeated_dependencies_codec); | |||||
| slotVariables_.WriteTo(output, _repeated_slotVariables_codec); | slotVariables_.WriteTo(output, _repeated_slotVariables_codec); | ||||
| if (kindCase_ == KindOneofCase.UserObject) { | if (kindCase_ == KindOneofCase.UserObject) { | ||||
| output.WriteRawTag(34); | output.WriteRawTag(34); | ||||
| @@ -587,6 +592,7 @@ namespace Tensorflow { | |||||
| public int CalculateSize() { | public int CalculateSize() { | ||||
| int size = 0; | int size = 0; | ||||
| size += children_.CalculateSize(_repeated_children_codec); | size += children_.CalculateSize(_repeated_children_codec); | ||||
| size += children_.CalculateSize(_repeated_dependencies_codec); | |||||
| size += slotVariables_.CalculateSize(_repeated_slotVariables_codec); | size += slotVariables_.CalculateSize(_repeated_slotVariables_codec); | ||||
| if (kindCase_ == KindOneofCase.UserObject) { | if (kindCase_ == KindOneofCase.UserObject) { | ||||
| size += 1 + pb::CodedOutputStream.ComputeMessageSize(UserObject); | size += 1 + pb::CodedOutputStream.ComputeMessageSize(UserObject); | ||||
| @@ -619,7 +625,7 @@ namespace Tensorflow { | |||||
| return size; | return size; | ||||
| } | } | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| //[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public void MergeFrom(SavedObject other) { | public void MergeFrom(SavedObject other) { | ||||
| if (other == null) { | if (other == null) { | ||||
| return; | return; | ||||
| @@ -682,7 +688,7 @@ namespace Tensorflow { | |||||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | ||||
| } | } | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| //[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public void MergeFrom(pb::CodedInputStream input) { | public void MergeFrom(pb::CodedInputStream input) { | ||||
| uint tag; | uint tag; | ||||
| while ((tag = input.ReadTag()) != 0) { | while ((tag = input.ReadTag()) != 0) { | ||||
| @@ -692,9 +698,10 @@ namespace Tensorflow { | |||||
| break; | break; | ||||
| case 10: { | case 10: { | ||||
| children_.AddEntriesFrom(input, _repeated_children_codec); | children_.AddEntriesFrom(input, _repeated_children_codec); | ||||
| dependencies_.AddRange(children_.Except(dependencies_)); | |||||
| break; | break; | ||||
| } | } | ||||
| case 26: { | |||||
| case 26: { | |||||
| slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); | slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -109,7 +109,12 @@ https://tensorflownet.readthedocs.io</Description> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | ||||
| <PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> | <PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> | ||||
| <PackageReference Include="OneOf" Version="3.0.223" /> | |||||
| <PackageReference Include="Protobuf.Text" Version="0.6.1" /> | <PackageReference Include="Protobuf.Text" Version="0.6.1" /> | ||||
| <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | |||||
| <ProjectReference Include="..\..\Tensorflow.Common\Tensorflow.Common.csproj" /> | |||||
| </ItemGroup> | |||||
| </Project> | </Project> | ||||
| @@ -87,6 +87,7 @@ namespace Tensorflow | |||||
| public object Tag { get; set; } | public object Tag { get; set; } | ||||
| protected new SafeTensorHandle _handle; | protected new SafeTensorHandle _handle; | ||||
| public virtual SafeTensorHandle Handle => _handle; | public virtual SafeTensorHandle Handle => _handle; | ||||
| public Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData { get; internal set; } | |||||
| protected SafeEagerTensorHandle _eagerTensorHandle; | protected SafeEagerTensorHandle _eagerTensorHandle; | ||||
| /// <summary> | /// <summary> | ||||
| @@ -14,18 +14,19 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using OneOf; | |||||
| using Tensorflow.Checkpoint; | using Tensorflow.Checkpoint; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class MySaveableObject | public class MySaveableObject | ||||
| { | { | ||||
| protected Maybe<Tensor, BaseResourceVariable> _op; | |||||
| protected OneOf<Tensor, BaseResourceVariable> _op; | |||||
| public Tensor op | public Tensor op | ||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| if(_op.TryGet<Tensor>(out var tensor)) | |||||
| if(_op.TryPickT0(out var tensor, out var _)) | |||||
| { | { | ||||
| return tensor; | return tensor; | ||||
| } | } | ||||
| @@ -43,7 +44,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| if (_op.TryGet<BaseResourceVariable>(out var v)) | |||||
| if (_op.TryPickT1(out var v, out var _)) | |||||
| { | { | ||||
| return v; | return v; | ||||
| } | } | ||||
| @@ -25,11 +25,32 @@ namespace Tensorflow.Training.Saving.SavedModel | |||||
| /// <param name="saved_concrete_function"></param> | /// <param name="saved_concrete_function"></param> | ||||
| /// <param name="concrete_functions"></param> | /// <param name="concrete_functions"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static ConcreteFunction recreate_function(SavedFunction saved_concrete_function, | |||||
| public static Function recreate_function(SavedFunction saved_function, | |||||
| IDictionary<string, ConcreteFunction> concrete_functions) | IDictionary<string, ConcreteFunction> concrete_functions) | ||||
| { | { | ||||
| var function_spec = _deserialize_function_spec_as_nonmethod(saved_concrete_function.FunctionSpec); | |||||
| return null; | |||||
| var function_spec = _deserialize_function_spec_as_nonmethod(saved_function.FunctionSpec); | |||||
| List<ConcreteFunction> concrete_function_objects = new(); | |||||
| foreach(var concrete_function_name in saved_function.ConcreteFunctions) | |||||
| { | |||||
| concrete_function_objects.Add(concrete_functions[concrete_function_name]); | |||||
| } | |||||
| foreach(var cf in concrete_function_objects) | |||||
| { | |||||
| cf._set_function_spec(function_spec); | |||||
| } | |||||
| foreach(var function_name in saved_function.ConcreteFunctions) | |||||
| { | |||||
| var function = concrete_functions[function_name]; | |||||
| if(_concrete_function_callable_with(function, null, false)) | |||||
| { | |||||
| return new RestoredFunction(null, function, "function_from_deserialization"); | |||||
| } | |||||
| } | |||||
| return new RestoredFunction(x => x, new ConcreteFunction(x => x, TF_DataType.TF_FLOAT), "function_return_itself"); | |||||
| //throw new ValueError("Unexpected runtime behavior, please submit an issue to " + | |||||
| // "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
| } | } | ||||
| public static Dictionary<string, ConcreteFunction> load_function_def_library(FunctionDefLibrary library, | public static Dictionary<string, ConcreteFunction> load_function_def_library(FunctionDefLibrary library, | ||||
| @@ -385,5 +406,31 @@ namespace Tensorflow.Training.Saving.SavedModel | |||||
| JitCompile = function_spec_proto.JitCompile | JitCompile = function_spec_proto.JitCompile | ||||
| }; | }; | ||||
| } | } | ||||
| private static Tensors _call_concrete_function(ConcreteFunction function, Tensors inputs) | |||||
| { | |||||
| // TODO(Rinne): var expected_structure = function.func_graph.structured_input_signature | |||||
| return function.CallFlat(inputs, function.CapturedInputs); | |||||
| } | |||||
| private static bool _concrete_function_callable_with(ConcreteFunction function, Tensors inputs, bool allow_conversion) | |||||
| { | |||||
| // TODO(Rinne): revise it. | |||||
| return true; | |||||
| } | |||||
| } | |||||
| public class RestoredFunction : Function | |||||
| { | |||||
| public RestoredFunction(Func<Tensors, Tensors> function, ConcreteFunction concrete_function, | |||||
| string name, bool auto_graph = true): base(function, name, auto_graph) | |||||
| { | |||||
| _concrete_variable_creation_fn = concrete_function; | |||||
| } | |||||
| protected override bool _run_functions_eagerly() | |||||
| { | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -14,6 +14,7 @@ using Tensorflow.Variables; | |||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| using Tensorflow.Training.Saving.SavedModel; | using Tensorflow.Training.Saving.SavedModel; | ||||
| using Tensorflow.Trackables; | using Tensorflow.Trackables; | ||||
| using OneOf; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -44,6 +45,8 @@ namespace Tensorflow | |||||
| _asset_file_def = meta_graph.AssetFileDef; | _asset_file_def = meta_graph.AssetFileDef; | ||||
| _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); | _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); | ||||
| _proto = object_graph_proto; | _proto = object_graph_proto; | ||||
| // Debug(Rinne) | |||||
| var temp = _proto.ToString(); | |||||
| _export_dir = export_dir; | _export_dir = export_dir; | ||||
| // TODO: `this._concrete_functions` and `this._restored_concrete_functions` | // TODO: `this._concrete_functions` and `this._restored_concrete_functions` | ||||
| _concrete_functions = function_deserialization.load_function_def_library( | _concrete_functions = function_deserialization.load_function_def_library( | ||||
| @@ -259,9 +262,9 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="proto"></param> | /// <param name="proto"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| private Dictionary<Maybe<string, int>, int> _get_node_dependencies(SavedObject proto) | |||||
| private Dictionary<OneOf<string, int>, int> _get_node_dependencies(SavedObject proto) | |||||
| { | { | ||||
| Dictionary<Maybe<string, int>, int> dependencies = new(); | |||||
| Dictionary<OneOf<string, int>, int> dependencies = new(); | |||||
| foreach(var refer in proto.Dependencies) | foreach(var refer in proto.Dependencies) | ||||
| { | { | ||||
| dependencies[refer.LocalName] = refer.NodeId; | dependencies[refer.LocalName] = refer.NodeId; | ||||
| @@ -375,11 +378,6 @@ namespace Tensorflow | |||||
| // Re-create everything. | // Re-create everything. | ||||
| foreach (var (node_id, proto) in _iter_all_nodes()) | foreach (var (node_id, proto) in _iter_all_nodes()) | ||||
| { | { | ||||
| if(node_id == 45) | |||||
| { | |||||
| // TODelete | |||||
| Console.WriteLine(); | |||||
| } | |||||
| if (nodes.ContainsKey(node_id)) | if (nodes.ContainsKey(node_id)) | ||||
| { | { | ||||
| continue; | continue; | ||||
| @@ -474,7 +472,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private void _setup_function_captures(string concrete_function_name, IDictionary<Maybe<string, int>, Trackable> nodes) | |||||
| private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, Trackable> nodes) | |||||
| { | { | ||||
| if (_restored_concrete_functions.Contains(concrete_function_name)) | if (_restored_concrete_functions.Contains(concrete_function_name)) | ||||
| { | { | ||||
| @@ -509,6 +507,11 @@ namespace Tensorflow | |||||
| /// <param name="node_id"></param> | /// <param name="node_id"></param> | ||||
| private void _add_object_graph_edges(SavedObject proto, int node_id) | private void _add_object_graph_edges(SavedObject proto, int node_id) | ||||
| { | { | ||||
| // Debug(Rinne) | |||||
| if(node_id == 1) | |||||
| { | |||||
| Console.WriteLine(); | |||||
| } | |||||
| var obj = _nodes[node_id]; | var obj = _nodes[node_id]; | ||||
| var setter = _node_setters[node_id]; | var setter = _node_setters[node_id]; | ||||
| @@ -549,8 +552,13 @@ namespace Tensorflow | |||||
| private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes) | private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes) | ||||
| { | { | ||||
| // skip the registered classes. | // skip the registered classes. | ||||
| if(node_id == 16) | |||||
| { | |||||
| // Debug(Rinne) | |||||
| Console.WriteLine(); | |||||
| } | |||||
| Dictionary<Maybe<string, int>, Trackable> dependencies = new(); | |||||
| Dictionary<OneOf<string, int>, Trackable> dependencies = new(); | |||||
| foreach(var item in _get_node_dependencies(proto)) | foreach(var item in _get_node_dependencies(proto)) | ||||
| { | { | ||||
| dependencies[item.Key] = nodes[item.Value]; | dependencies[item.Key] = nodes[item.Value]; | ||||
| @@ -571,7 +579,7 @@ namespace Tensorflow | |||||
| /// <param name="proto"></param> | /// <param name="proto"></param> | ||||
| /// <param name="node_id"></param> | /// <param name="node_id"></param> | ||||
| /// <param name="dependencies"></param> | /// <param name="dependencies"></param> | ||||
| private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<Maybe<string, int>, Trackable> dependencies) | |||||
| private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, Trackable> dependencies) | |||||
| { | { | ||||
| return proto.KindCase switch | return proto.KindCase switch | ||||
| { | { | ||||
| @@ -637,10 +645,10 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto, | |||||
| Dictionary<Maybe<string, int>, Trackable> dependencies) | |||||
| private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto, | |||||
| Dictionary<OneOf<string, int>, Trackable> dependencies) | |||||
| { | { | ||||
| var fn = function_deserialization.recreate_function(proto, null); | |||||
| var fn = function_deserialization.recreate_function(proto, _concrete_functions); | |||||
| foreach (var name in proto.ConcreteFunctions) | foreach (var name in proto.ConcreteFunctions) | ||||
| { | { | ||||
| _setup_function_captures(name, dependencies); | _setup_function_captures(name, dependencies); | ||||
| @@ -649,7 +657,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | ||||
| IDictionary<Maybe<string, int>, Trackable> dependencies) | |||||
| IDictionary<OneOf<string, int>, Trackable> dependencies) | |||||
| { | { | ||||
| var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); | var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); | ||||
| _setup_function_captures(proto.ConcreteFunctionName, dependencies); | _setup_function_captures(proto.ConcreteFunctionName, dependencies); | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using OneOf; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| @@ -174,7 +175,7 @@ namespace Tensorflow | |||||
| full_name = name + "_" + attr; | full_name = name + "_" + attr; | ||||
| } | } | ||||
| var op = factory(full_name); | var op = factory(full_name); | ||||
| if(op.TryGet<BaseResourceVariable>(out var variable)) | |||||
| if(op.TryPickT0(out var variable, out var saveable)) | |||||
| { | { | ||||
| foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) | foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) | ||||
| { | { | ||||
| @@ -183,7 +184,6 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var saveable = op.GetValue<MySaveableObject>(); | |||||
| foreach (var v in saveable_objects_for_op(saveable, saveable.name)) | foreach (var v in saveable_objects_for_op(saveable, saveable.name)) | ||||
| { | { | ||||
| yield return v; | yield return v; | ||||
| @@ -252,11 +252,11 @@ namespace Tensorflow | |||||
| return names_to_saveables; | return names_to_saveables; | ||||
| } | } | ||||
| public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj) | |||||
| public static IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj) | |||||
| { | { | ||||
| // skip the process of type `PythonState` | // skip the process of type `PythonState` | ||||
| Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||||
| OneOf<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||||
| { | { | ||||
| // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. | // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. | ||||
| var tensor_dict = obj.serialize_to_tensors(); | var tensor_dict = obj.serialize_to_tensors(); | ||||
| @@ -272,14 +272,14 @@ namespace Tensorflow | |||||
| string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | ||||
| IDictionary<string, Tensor> internal_dict; | IDictionary<string, Tensor> internal_dict; | ||||
| if (maybe_tensor.TryGet<Tensor>(out var tensor)) | |||||
| if (maybe_tensor.TryPickT0(out var tensor, out var dic)) | |||||
| { | { | ||||
| internal_dict = new Dictionary<string, Tensor>(); | internal_dict = new Dictionary<string, Tensor>(); | ||||
| internal_dict[""] = tensor; | internal_dict[""] = tensor; | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | |||||
| internal_dict = dic; | |||||
| } | } | ||||
| foreach (var item in internal_dict) | foreach (var item in internal_dict) | ||||
| @@ -292,7 +292,7 @@ namespace Tensorflow | |||||
| if (trackable_has_serialize_to_tensor(obj)) | if (trackable_has_serialize_to_tensor(obj)) | ||||
| { | { | ||||
| Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||||
| Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> res = new(); | |||||
| res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; | res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; | ||||
| return res; | return res; | ||||
| } | } | ||||
| @@ -316,9 +316,9 @@ namespace Tensorflow | |||||
| /// Converts a list of SaveableObjects to a tensor dictionary. | /// Converts a list of SaveableObjects to a tensor dictionary. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="saveables"></param> | /// <param name="saveables"></param> | ||||
| public static Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||||
| public static Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||||
| { | { | ||||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||||
| foreach (var saveable in saveables) | foreach (var saveable in saveables) | ||||
| { | { | ||||
| foreach (var spec in saveable.specs) | foreach (var spec in saveable.specs) | ||||
| @@ -328,7 +328,7 @@ namespace Tensorflow | |||||
| var slice_spec = convert_to_string(spec.slice_spec); | var slice_spec = convert_to_string(spec.slice_spec); | ||||
| if (!string.IsNullOrEmpty(slice_spec)) | if (!string.IsNullOrEmpty(slice_spec)) | ||||
| { | { | ||||
| tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValue<IDictionary<string, Tensor>>()[slice_spec] = spec.tensor; | |||||
| tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).AsT1[slice_spec] = spec.tensor; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -343,7 +343,7 @@ namespace Tensorflow | |||||
| /// Generates `Trackable._restore_from_tensors` from SaveableObjects. | /// Generates `Trackable._restore_from_tensors` from SaveableObjects. | ||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Func<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>> saveable_object_to_restore_fn(IList<MySaveableObject> saveables) | |||||
| public static Func<IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>> saveable_object_to_restore_fn(IList<MySaveableObject> saveables) | |||||
| { | { | ||||
| return (restored_tensors) => | return (restored_tensors) => | ||||
| { | { | ||||
| @@ -359,14 +359,14 @@ namespace Tensorflow | |||||
| var maybe_tensor = restored_tensors[name]; | var maybe_tensor = restored_tensors[name]; | ||||
| IDictionary<string, Tensor> dict; | IDictionary<string, Tensor> dict; | ||||
| if(maybe_tensor.TryGet<Tensor>(out var tensor)) | |||||
| if(maybe_tensor.TryPickT0(out var tensor, out var dic)) | |||||
| { | { | ||||
| dict = new Dictionary<string, Tensor>(); | dict = new Dictionary<string, Tensor>(); | ||||
| dict[""] = tensor; | dict[""] = tensor; | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | |||||
| dict = dic; | |||||
| } | } | ||||
| saveable_restored_tensors.Add(dict[slice_spec]); | saveable_restored_tensors.Add(dict[slice_spec]); | ||||
| } | } | ||||
| @@ -381,18 +381,18 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="saveable_fn_by_name"></param> | /// <param name="saveable_fn_by_name"></param> | ||||
| /// <param name="temp_session"></param> | /// <param name="temp_session"></param> | ||||
| public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects( | |||||
| public static IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects( | |||||
| IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session) | IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session) | ||||
| { | { | ||||
| if (saveable_fn_by_name.Count > 0) | if (saveable_fn_by_name.Count > 0) | ||||
| { | { | ||||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | ||||
| } | } | ||||
| var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||||
| var res = new Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>(); | |||||
| return res; | return res; | ||||
| } | } | ||||
| public static Maybe<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||||
| public static OneOf<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, OneOf<BaseResourceVariable, MySaveableObject>> factory, | |||||
| bool call_with_mapped_captures = false) | bool call_with_mapped_captures = false) | ||||
| { | { | ||||
| return factory(key); | return factory(key); | ||||
| @@ -412,7 +412,7 @@ namespace Tensorflow | |||||
| public object Obj => _obj; | public object Obj => _obj; | ||||
| public IList<MySaveableObject> mySaveables=> _saveables; | public IList<MySaveableObject> mySaveables=> _saveables; | ||||
| public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||||
| public override IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||||
| { | { | ||||
| return saveable_object_util.saveable_object_to_tensor_dict(_saveables); | return saveable_object_util.saveable_object_to_tensor_dict(_saveables); | ||||
| } | } | ||||
| @@ -422,7 +422,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="restored_tensors"></param> | /// <param name="restored_tensors"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public override IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors) | |||||
| public override IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors) | |||||
| { | { | ||||
| List<string> expected_keys = new(); | List<string> expected_keys = new(); | ||||
| foreach(var saveable in _saveables) | foreach(var saveable in _saveables) | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using OneOf; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| @@ -43,8 +44,8 @@ namespace Tensorflow.Train | |||||
| protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | ||||
| protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies; | protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies; | ||||
| protected IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories = | |||||
| new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||||
| protected IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories = | |||||
| new Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>(); | |||||
| private bool _manual_tracking = true; | private bool _manual_tracking = true; | ||||
| private static Trackable _none = new AutoTrackable(); | private static Trackable _none = new AutoTrackable(); | ||||
| @@ -73,7 +74,7 @@ namespace Tensorflow.Train | |||||
| public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | ||||
| public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | ||||
| public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies; | public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies; | ||||
| public IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories | |||||
| public IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| @@ -249,9 +250,9 @@ namespace Tensorflow.Train | |||||
| return self_tensor_map.Keys.ToList(); | return self_tensor_map.Keys.ToList(); | ||||
| } | } | ||||
| public virtual IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||||
| public virtual IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||||
| { | { | ||||
| Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||||
| OneOf<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| //return new TrackableSaveable(this, null, name, null, null); | //return new TrackableSaveable(this, null, name, null, null); | ||||
| @@ -259,7 +260,7 @@ namespace Tensorflow.Train | |||||
| if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | ||||
| { | { | ||||
| // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | ||||
| Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||||
| Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> res = new(); | |||||
| res[""] = create_saveable; | res[""] = create_saveable; | ||||
| return res; | return res; | ||||
| } | } | ||||
| @@ -278,12 +279,12 @@ namespace Tensorflow.Train | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| /// <exception cref="NotImplementedException"></exception> | /// <exception cref="NotImplementedException"></exception> | ||||
| public virtual IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||||
| public virtual IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| public virtual IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors) | |||||
| public virtual IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors) | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| @@ -0,0 +1,23 @@ | |||||
| using Google.Protobuf; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Util | |||||
| { | |||||
| internal static class function_utils | |||||
| { | |||||
| private static string _rewriter_config_optimizer_disabled; | |||||
| public static string get_disabled_rewriter_config() | |||||
| { | |||||
| if(_rewriter_config_optimizer_disabled is null) | |||||
| { | |||||
| var config = new ConfigProto(); | |||||
| var rewriter_config = config.GraphOptions.RewriteOptions; | |||||
| rewriter_config.DisableMetaOptimizer = true; | |||||
| _rewriter_config_optimizer_disabled = config.ToString(); | |||||
| } | |||||
| return _rewriter_config_optimizer_disabled; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -8,6 +8,7 @@ using System.Collections.Generic; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using Tensorflow.Checkpoint; | using Tensorflow.Checkpoint; | ||||
| using Tensorflow.Training.Saving.SavedModel; | using Tensorflow.Training.Saving.SavedModel; | ||||
| using OneOf; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -155,7 +156,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| variable_accessed(this); | variable_accessed(this); | ||||
| var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | ||||
| // _maybe_set_handle_data(_dtype, _handle, result); | |||||
| resource_variable_ops._maybe_set_handle_data(_dtype, handle, result); | |||||
| // have to set shape when converting to substituent placeholder | // have to set shape when converting to substituent placeholder | ||||
| if (result.shape.ndim == -1) | if (result.shape.ndim == -1) | ||||
| @@ -293,9 +294,9 @@ namespace Tensorflow | |||||
| resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | ||||
| } | } | ||||
| public override IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||||
| public override IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||||
| { | { | ||||
| var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||||
| var res = new Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>(); | |||||
| res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; | res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; | ||||
| return res; | return res; | ||||
| } | } | ||||
| @@ -124,7 +124,9 @@ namespace Tensorflow | |||||
| initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; | initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; | ||||
| ops.colocate_with(initializer_op); | ops.colocate_with(initializer_op); | ||||
| tf.device(handle.Device); | |||||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||||
| _graph_element = gen_array_ops.identity(handle, name = "read"); | _graph_element = gen_array_ops.identity(handle, name = "read"); | ||||
| ops.add_to_collections<IVariableV1>(collections, this); | ops.add_to_collections<IVariableV1>(collections, this); | ||||
| _dtype = handle.dtype; | _dtype = handle.dtype; | ||||
| @@ -141,6 +143,12 @@ namespace Tensorflow | |||||
| gen_resource_variable_ops.assign_variable_op(handle, _initial_value); | gen_resource_variable_ops.assign_variable_op(handle, _initial_value); | ||||
| initializer_op = null; | initializer_op = null; | ||||
| _graph_element = null; | _graph_element = null; | ||||
| if (!string.IsNullOrEmpty(caching_device)) | |||||
| { | |||||
| tf.device(caching_device); | |||||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||||
| } | |||||
| _dtype = _initial_value.dtype.as_base_dtype(); | _dtype = _initial_value.dtype.as_base_dtype(); | ||||
| // initial_value = _in_graph_mode ? initial_value : null; | // initial_value = _in_graph_mode ? initial_value : null; | ||||
| } | } | ||||
| @@ -9,7 +9,7 @@ namespace Tensorflow.Variables | |||||
| /// <summary> | /// <summary> | ||||
| /// A variable with no initializer. | /// A variable with no initializer. | ||||
| /// </summary> | /// </summary> | ||||
| public sealed class UninitializedVariable: BaseResourceVariable | |||||
| public sealed class UninitializedVariable: BaseResourceVariable, IVariableV1 | |||||
| { | { | ||||
| // TODO: complete the arg list. | // TODO: complete the arg list. | ||||
| public UninitializedVariable( | public UninitializedVariable( | ||||
| @@ -23,6 +23,7 @@ namespace Tensorflow.Variables | |||||
| { | { | ||||
| string unique_id = ""; | string unique_id = ""; | ||||
| string handle_name = ""; | string handle_name = ""; | ||||
| Tensor created_handle = null; | |||||
| tf_with(ops.init_scope(), (x) => | tf_with(ops.init_scope(), (x) => | ||||
| { | { | ||||
| _in_graph_mode = !tf.Context.executing_eagerly(); | _in_graph_mode = !tf.Context.executing_eagerly(); | ||||
| @@ -40,7 +41,7 @@ namespace Tensorflow.Variables | |||||
| unique_id = $"{handle_name}-{ops.uid()}"; | unique_id = $"{handle_name}-{ops.uid()}"; | ||||
| shared_name = null; | shared_name = null; | ||||
| } | } | ||||
| var handle = resource_variable_ops.variable_handle_from_shape_and_dtype( | |||||
| created_handle = resource_variable_ops.variable_handle_from_shape_and_dtype( | |||||
| shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data); | shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data); | ||||
| // skip the assignment of `handle._parent_trackable` because of lack of API. | // skip the assignment of `handle._parent_trackable` because of lack of API. | ||||
| // skip the assignment of `handle._name` and `handle._unique_id` because of accessability. | // skip the assignment of `handle._name` and `handle._unique_id` because of accessability. | ||||
| @@ -51,7 +52,7 @@ namespace Tensorflow.Variables | |||||
| { | { | ||||
| tf.device(handle.Device); | tf.device(handle.Device); | ||||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | ||||
| // _maybe_set_handle_data(dtype, handle, value) | |||||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||||
| _graph_element = value; | _graph_element = value; | ||||
| }); | }); | ||||
| ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); | ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); | ||||
| @@ -64,7 +65,7 @@ namespace Tensorflow.Variables | |||||
| }); | }); | ||||
| _shape = shape; | _shape = shape; | ||||
| _dtype = dtype; | _dtype = dtype; | ||||
| base.__init__(trainable, handle, unique_id: unique_id, handle_name: handle_name); | |||||
| base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -26,6 +26,7 @@ using Tensorflow.Eager; | |||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -572,9 +573,12 @@ namespace Tensorflow | |||||
| return get_default_graph().building_function; | return get_default_graph().building_function; | ||||
| } | } | ||||
| public static SafeTensorHandle get_resource_handle_data(Tensor graph_op) | |||||
| public static HandleData get_resource_handle_data(Tensor graph_op) | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| // This implementation hasn't been checked for some reasons. | |||||
| // If it throws an exception in the future, please check it. | |||||
| var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | |||||
| return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); | |||||
| } | } | ||||
| public static void dismantle_graph(Graph graph) | public static void dismantle_graph(Graph graph) | ||||
| @@ -40,7 +40,7 @@ namespace Tensorflow.Keras.Engine | |||||
| /// <summary> | /// <summary> | ||||
| /// Arguments initialize layer. | /// Arguments initialize layer. | ||||
| /// </summary> | /// </summary> | ||||
| LayerArgs args; | |||||
| internal LayerArgs args; | |||||
| /// <summary> | /// <summary> | ||||
| /// Indicates whether `build` needs to be called upon layer call, to create | /// Indicates whether `build` needs to be called upon layer call, to create | ||||
| @@ -147,7 +147,7 @@ namespace Tensorflow.Keras.Engine | |||||
| List<INode> outboundNodes; | List<INode> outboundNodes; | ||||
| public List<INode> OutboundNodes => outboundNodes; | public List<INode> OutboundNodes => outboundNodes; | ||||
| public JObject SerializedAttributes { get; set; } | |||||
| public Dictionary<string, object> SerializedAttributes { get; set; } | |||||
| ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | ||||
| public CallContext CallContext => callContext.Value; | public CallContext CallContext => callContext.Value; | ||||
| @@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Saving | |||||
| { | { | ||||
| public class KerasObjectLoader | public class KerasObjectLoader | ||||
| { | { | ||||
| internal static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||||
| internal static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES; | |||||
| private SavedMetadata _metadata; | private SavedMetadata _metadata; | ||||
| private SavedObjectGraph _proto; | private SavedObjectGraph _proto; | ||||
| private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | ||||
| @@ -39,7 +39,13 @@ namespace Tensorflow.Keras.Saving | |||||
| static KerasObjectLoader() | static KerasObjectLoader() | ||||
| { | { | ||||
| PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null; | |||||
| var endPoints = new CommonEndPoints(); | |||||
| PUBLIC_ATTRIBUTES = new Dictionary<string, Trackable>(); | |||||
| foreach (var key in endPoints._all_checkpointable_objects.Concat(endPoints._all_functions)) | |||||
| { | |||||
| PUBLIC_ATTRIBUTES[key] = null; | |||||
| } | |||||
| PUBLIC_ATTRIBUTES[SavedModel.Constants.KERAS_ATTR] = null; | |||||
| } | } | ||||
| public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) | public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) | ||||
| @@ -125,8 +131,14 @@ namespace Tensorflow.Keras.Saving | |||||
| continue; | continue; | ||||
| } | } | ||||
| // TODO: deal with `RevivedLayer` and `RevivedInputLayer`. | |||||
| layers_revived_from_config.Add(node as Layer); | |||||
| if(node is RevivedLayer or RevivedInputLayer) | |||||
| { | |||||
| layers_revived_from_saved_model.Add(node as Layer); | |||||
| } | |||||
| else | |||||
| { | |||||
| layers_revived_from_config.Add(node as Layer); | |||||
| } | |||||
| } | } | ||||
| _finalize_saved_model_layers(layers_revived_from_saved_model); | _finalize_saved_model_layers(layers_revived_from_saved_model); | ||||
| @@ -171,10 +183,13 @@ namespace Tensorflow.Keras.Saving | |||||
| // TODO(Rinne): implement it | // TODO(Rinne): implement it | ||||
| } | } | ||||
| } | } | ||||
| // `model.__init__(layers, config["name"])` | |||||
| s.InitLayers(layers); | |||||
| s.Name = config["name"].ToObject<string>(); | |||||
| // `model.__init__(layers, config["name"])`InitLayers(layers); | |||||
| s = new Sequential(new SequentialArgs(){ | |||||
| Layers = layers.Select(x => x as ILayer).ToList(), | |||||
| Name = config["name"].ToObject<string>() | |||||
| }); | |||||
| //s.Name = config["name"].ToObject<string>(); | |||||
| if(s.input is null || s.input.Length == 0) | if(s.input is null || s.input.Length == 0) | ||||
| { | { | ||||
| var first_layer = _get_child_layer_node_ids(model_id)[0]; | var first_layer = _get_child_layer_node_ids(model_id)[0]; | ||||
| @@ -205,7 +220,12 @@ namespace Tensorflow.Keras.Saving | |||||
| private void _set_network_attributes_from_metadata(Model revived_object) | private void _set_network_attributes_from_metadata(Model revived_object) | ||||
| { | { | ||||
| // TODO: implement it. | |||||
| var metadata = revived_object.SerializedAttributes["matadata"] as JObject; | |||||
| if (metadata.ContainsKey("dtype")) | |||||
| { | |||||
| // TODO(Rinne): set_dtype_policy. | |||||
| } | |||||
| revived_object.args.Trainable = metadata["trainable"].Value<bool>(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -330,7 +350,7 @@ namespace Tensorflow.Keras.Saving | |||||
| private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id) | private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id) | ||||
| { | { | ||||
| Trackable obj; | Trackable obj; | ||||
| if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) | |||||
| if(identifier == SavedModel.Constants.METRIC_IDENTIFIER) | |||||
| { | { | ||||
| // TODO(Rinne): implement it. | // TODO(Rinne): implement it. | ||||
| return (null, null); | return (null, null); | ||||
| @@ -429,25 +449,26 @@ namespace Tensorflow.Keras.Saving | |||||
| return obj; | return obj; | ||||
| } | } | ||||
| private void _revive_setter(object layer, object name, object value) | |||||
| private void _revive_setter(object obj, object name, object value) | |||||
| { | { | ||||
| Debug.Assert(name is string); | Debug.Assert(name is string); | ||||
| Debug.Assert(layer is Layer); | |||||
| Debug.Assert(obj is Layer); | |||||
| Layer layer = (Layer)obj; | |||||
| if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | ||||
| { | { | ||||
| if(value is Trackable) | if(value is Trackable) | ||||
| { | { | ||||
| (layer as Layer)._track_trackable(value as Trackable, name as string); | |||||
| layer._track_trackable(value as Trackable, name as string); | |||||
| } | } | ||||
| if((layer as Layer).SerializedAttributes is null) | |||||
| if(layer.SerializedAttributes is null) | |||||
| { | { | ||||
| (layer as Layer).SerializedAttributes = new JObject(); | |||||
| layer.SerializedAttributes = new Dictionary<string, object>(); | |||||
| } | } | ||||
| (layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value); | |||||
| layer.SerializedAttributes[name as string] = value; | |||||
| } | } | ||||
| else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||||
| else if(layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||||
| { | { | ||||
| (layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true); | |||||
| functional._track_trackable(value as Trackable, name as string, overwrite: true); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -521,7 +542,7 @@ namespace Tensorflow.Keras.Saving | |||||
| } | } | ||||
| var metric_list_node_id = _search_for_child_node(node_id, new string[] { | var metric_list_node_id = _search_for_child_node(node_id, new string[] { | ||||
| Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics" | |||||
| SavedModel.Constants.KERAS_ATTR, "layer_metrics" | |||||
| }); | }); | ||||
| if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) | if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) | ||||
| { | { | ||||
| @@ -547,7 +568,7 @@ namespace Tensorflow.Keras.Saving | |||||
| // skip the check for registered identifier | // skip the check for registered identifier | ||||
| Action<object, object, object> setter; | Action<object, object, object> setter; | ||||
| if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) | |||||
| if (SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) | |||||
| { | { | ||||
| setter = _revive_setter; | setter = _revive_setter; | ||||
| } | } | ||||
| @@ -659,7 +680,23 @@ namespace Tensorflow.Keras.Saving | |||||
| private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata) | private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata) | ||||
| { | { | ||||
| // TODO: deal with `RevivedLayer` | |||||
| if(layer.SerializedAttributes is null || layer.SerializedAttributes.Count == 0) | |||||
| { | |||||
| layer.SerializedAttributes = new Dictionary<string, object>(); | |||||
| layer.SerializedAttributes["metadata"] = metadata; | |||||
| } | |||||
| } | |||||
| private static object _get_keras_attr(Layer layer) | |||||
| { | |||||
| if((layer.SerializedAttributes ?? new Dictionary<string, object>()).TryGetValue(SavedModel.Constants.KERAS_ATTR, out var value)) | |||||
| { | |||||
| return value; | |||||
| } | |||||
| else | |||||
| { | |||||
| return null; | |||||
| } | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -24,17 +24,22 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| } | } | ||||
| } | } | ||||
| public static void _revive_setter(object layer, object name, object value) | |||||
| public static void _revive_setter(object obj, object name, object value) | |||||
| { | { | ||||
| Debug.Assert(name is string); | Debug.Assert(name is string); | ||||
| Debug.Assert(layer is Layer); | |||||
| Debug.Assert(obj is Layer); | |||||
| Layer layer = (Layer)obj; | |||||
| if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | ||||
| { | { | ||||
| if (value is Trackable trackable) | if (value is Trackable trackable) | ||||
| { | { | ||||
| (layer as Layer)._track_trackable(trackable, name as string); | |||||
| layer._track_trackable(trackable, name as string); | |||||
| } | } | ||||
| (layer as Layer).SerializedAttributes[name] = JToken.FromObject(value); | |||||
| if (layer.SerializedAttributes is null) | |||||
| { | |||||
| layer.SerializedAttributes = new Dictionary<string, object>(); | |||||
| } | |||||
| layer.SerializedAttributes[name as string] = value; | |||||
| } | } | ||||
| else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | ||||
| { | { | ||||
| @@ -0,0 +1,15 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel | |||||
| { | |||||
| public class RevivedInputLayer: Layer | |||||
| { | |||||
| private RevivedInputLayer(): base(null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -55,6 +55,21 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| private RevivedConfig _config = null; | private RevivedConfig _config = null; | ||||
| public object keras_api | |||||
| { | |||||
| get | |||||
| { | |||||
| if (SerializedAttributes.TryGetValue(SavedModel.Constants.KERAS_ATTR, out var value)) | |||||
| { | |||||
| return value; | |||||
| } | |||||
| else | |||||
| { | |||||
| return null; | |||||
| } | |||||
| } | |||||
| } | |||||
| public RevivedLayer(LayerArgs args): base(args) | public RevivedLayer(LayerArgs args): base(args) | ||||
| { | { | ||||
| @@ -69,5 +84,17 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| { | { | ||||
| return _config; | return _config; | ||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
| { | |||||
| if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function) | |||||
| { | |||||
| return base.Call(inputs, state, training); | |||||
| } | |||||
| else | |||||
| { | |||||
| return (func as Function).Apply(inputs); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -19,8 +19,8 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| protected IDictionary<string, Trackable?> _object_dict; | protected IDictionary<string, Trackable?> _object_dict; | ||||
| protected IDictionary<string, Trackable?> _function_dict; | protected IDictionary<string, Trackable?> _function_dict; | ||||
| protected AutoTrackable _keras_trackable; | protected AutoTrackable _keras_trackable; | ||||
| protected HashSet<string> _all_functions; | |||||
| protected HashSet<string> _all_checkpointable_objects; | |||||
| internal HashSet<string> _all_functions; | |||||
| internal HashSet<string> _all_checkpointable_objects; | |||||
| private SerializedAttributes() | private SerializedAttributes() | ||||
| { | { | ||||
| @@ -197,19 +197,15 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| public class CommonEndPoints: SerializedAttributes | public class CommonEndPoints: SerializedAttributes | ||||
| { | { | ||||
| public CommonEndPoints(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : | public CommonEndPoints(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : | ||||
| //base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), | |||||
| // functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) | |||||
| base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables"}), | |||||
| functions.Concat(new string[] { })) | |||||
| base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), | |||||
| functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) | |||||
| { | { | ||||
| } | } | ||||
| public CommonEndPoints() : | public CommonEndPoints() : | ||||
| //base(new string[] { "variables", "trainable_variables", "regularization_losses" }, | |||||
| // new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) | |||||
| base(new string[] { "variables", "trainable_variables"}, | |||||
| new string[] {}) | |||||
| base(new string[] { "variables", "trainable_variables", "regularization_losses" }, | |||||
| new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) | |||||
| { | { | ||||
| } | } | ||||