| @@ -1,7 +1,7 @@ | |||
| | |||
| Microsoft Visual Studio Solution File, Format Version 12.00 | |||
| # Visual Studio Version 16 | |||
| VisualStudioVersion = 16.0.31624.102 | |||
| # Visual Studio Version 17 | |||
| VisualStudioVersion = 17.4.33213.308 | |||
| MinimumVisualStudioVersion = 10.0.40219.1 | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" | |||
| EndProject | |||
| @@ -23,6 +23,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" | |||
| EndProject | |||
| Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Common", "Tensorflow.Common\Tensorflow.Common.csproj", "{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}" | |||
| EndProject | |||
| Global | |||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
| Debug|Any CPU = Debug|Any CPU | |||
| @@ -153,6 +155,18 @@ Global | |||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 | |||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU | |||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.ActiveCfg = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.Build.0 = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.ActiveCfg = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.Build.0 = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.ActiveCfg = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.Build.0 = Release|Any CPU | |||
| EndGlobalSection | |||
| GlobalSection(SolutionProperties) = preSolution | |||
| HideSolutionNode = FALSE | |||
| @@ -0,0 +1,13 @@ | |||
| using OneOf; | |||
| using System; | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class OneofExtension | |||
| { | |||
| public static bool IsTypeOrDeriveFrom<T>(this IOneOf src) | |||
| { | |||
| return src.Value is T; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,11 @@ | |||
| <Project Sdk="Microsoft.NET.Sdk"> | |||
| <PropertyGroup> | |||
| <TargetFramework>netstandard2.0</TargetFramework> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="OneOf" Version="3.0.223" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -1,10 +1,12 @@ | |||
| using System; | |||
| using OneOf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using Tensorflow.Common.Extensions; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| namespace Tensorflow.Checkpoint | |||
| @@ -28,7 +30,7 @@ namespace Tensorflow.Checkpoint | |||
| ); | |||
| public static class SaveUtil | |||
| { | |||
| public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| public static (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | |||
| { | |||
| var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | |||
| @@ -117,16 +119,16 @@ namespace Tensorflow.Checkpoint | |||
| /// <param name="call_with_mapped_captures"></param> | |||
| /// <param name="cache"></param> | |||
| /// <param name="object_graph_proto"></param> | |||
| private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
| private static IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
| bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | |||
| { | |||
| Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
| Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
| foreach(var td in tensor_trackables) | |||
| { | |||
| // TODO: deal with cache. | |||
| var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | |||
| Trackable trackable = null; | |||
| IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||
| IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||
| if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | |||
| { | |||
| (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | |||
| @@ -148,12 +150,12 @@ namespace Tensorflow.Checkpoint | |||
| return serialized_tensors; | |||
| } | |||
| private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
| private static IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
| { | |||
| var trackable = trackable_data.object_to_save; | |||
| // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | |||
| IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||
| IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||
| if (call_with_mapped_captures) | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -164,7 +166,7 @@ namespace Tensorflow.Checkpoint | |||
| } | |||
| // TODO: deal with the type `SaveSpce` (currently it will never be it). | |||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
| foreach(var pair in ret_tensor_dict) | |||
| { | |||
| var local_name = TrackableUtils.escape_local_name(pair.Key); | |||
| @@ -200,7 +202,7 @@ namespace Tensorflow.Checkpoint | |||
| /// <param name="call_with_mapped_captures"></param> | |||
| /// <param name="object_graph_proto"></param> | |||
| /// <returns></returns> | |||
| private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
| private static (Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
| bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
| { | |||
| Dictionary<Trackable, string> object_names = new(); | |||
| @@ -8,6 +8,7 @@ using Tensorflow.Training; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using static Tensorflow.Binding; | |||
| using Google.Protobuf; | |||
| using OneOf; | |||
| namespace Tensorflow.Checkpoint; | |||
| @@ -179,13 +180,13 @@ public static class SaveUtilV1 | |||
| // TODO: tensorflow python has a process with callable `saveable_factory`. | |||
| List<MySaveableObject> saveables = new(); | |||
| if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | |||
| if (maybe_saveable.TryPickT1(out var s, out var variable)) | |||
| { | |||
| saveables.Add(s); | |||
| } | |||
| else | |||
| { | |||
| saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key)); | |||
| saveables.AddRange(saveable_object_util.saveable_objects_for_op(variable as Trackable, key)); | |||
| } | |||
| foreach (var saveable in saveables) | |||
| @@ -217,7 +218,7 @@ public static class SaveUtilV1 | |||
| public record class CheckpointFactoryData | |||
| ( | |||
| Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||
| Func<string, OneOf<BaseResourceVariable, MySaveableObject>> factory, | |||
| string name, | |||
| string checkpoint_key | |||
| ); | |||
| @@ -12,6 +12,7 @@ using static Tensorflow.Binding; | |||
| using Tensorflow.Operations; | |||
| using Newtonsoft.Json; | |||
| using Tensorflow.Training; | |||
| using OneOf; | |||
| namespace Tensorflow.Checkpoint; | |||
| @@ -49,7 +50,7 @@ public class TrackableSaver | |||
| } | |||
| private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| private (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| gather_serialized_tensors(Tensor? object_graph_tensor = null) | |||
| { | |||
| var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | |||
| @@ -68,7 +69,7 @@ public class TrackableSaver | |||
| Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
| if (!serialized_tensors.ContainsKey(Trackable.None)) | |||
| { | |||
| serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>(); | |||
| serialized_tensors[Trackable.None] = new Dictionary<string, OneOf.OneOf<Tensor, IDictionary<string, Tensor>>>(); | |||
| } | |||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; | |||
| return (serialized_tensors, feed_additions, registered_savers, graph_proto); | |||
| @@ -400,7 +401,7 @@ public class CheckpointRestoreCoordinator | |||
| // skip the callback. | |||
| } | |||
| public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||
| public List<Operation> restore_saveables(Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||
| { | |||
| List<Operation> restore_ops = new(); | |||
| foreach(var position in positions) | |||
| @@ -412,7 +413,7 @@ public class CheckpointRestoreCoordinator | |||
| Dictionary<string, BaseResourceVariable> variable_dict = new(); | |||
| foreach(var item in tensor_saveables) | |||
| { | |||
| if(item.Value.TryGet<BaseResourceVariable>(out var variable)) | |||
| if(item.Value.TryPickT0(out var variable, out var _)) | |||
| { | |||
| variable_dict[item.Key] = variable; | |||
| } | |||
| @@ -15,106 +15,14 @@ using Tensorflow.Graphs; | |||
| using System.Xml.Linq; | |||
| using System.Diagnostics; | |||
| using RestoreFunc = System.Func<object, object>; | |||
| using OneOf; | |||
| namespace Tensorflow.Checkpoint | |||
| { | |||
| public class Maybe<TA, TB> | |||
| { | |||
| private TA? _valueA = default(TA); | |||
| private TB? _valueB = default(TB); | |||
| private Type _type; | |||
| private bool _assignedTA; | |||
| public Maybe(TA value) | |||
| { | |||
| _valueA = value; | |||
| _type= typeof(TA); | |||
| _assignedTA = true; | |||
| } | |||
| public Maybe(TB value) | |||
| { | |||
| _valueB = value; | |||
| _type = typeof(TB); | |||
| _assignedTA = false; | |||
| } | |||
| public Type DataType => _type; | |||
| /// <summary> | |||
| /// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned. | |||
| /// It returns | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="res"></param> | |||
| /// <returns></returns> | |||
| public bool TryGet<T>(out T? res) | |||
| { | |||
| if(_valueA is T && _valueB is not T) | |||
| { | |||
| res = (T)(object)_valueA; | |||
| return _assignedTA; | |||
| } | |||
| else if(_valueA is not T && _valueB is T) | |||
| { | |||
| res = (T)(object)_valueB; | |||
| return !_assignedTA; | |||
| } | |||
| res = default(T); | |||
| return false; | |||
| } | |||
| public bool IsTypeOrDeriveFrom<T>() | |||
| { | |||
| if (_valueA is T && _valueB is not T) | |||
| { | |||
| return _assignedTA; | |||
| } | |||
| else if (_valueA is not T && _valueB is T) | |||
| { | |||
| return !_assignedTA; | |||
| } | |||
| else if (_valueA is T && _valueB is T) | |||
| { | |||
| return true; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| public T GetValue<T>() | |||
| { | |||
| if (_valueA is T && _valueB is not T) | |||
| { | |||
| return (T)(object)_valueA; | |||
| } | |||
| else if (_valueA is not T && _valueB is T) | |||
| { | |||
| return (T)(object)_valueB; | |||
| } | |||
| else if (_valueA is T && _valueB is T) | |||
| { | |||
| throw new TypeError("The type is vague, this is always because TA and TB both derive from T."); | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}."); | |||
| } | |||
| } | |||
| public static implicit operator Maybe<TA, TB>(TA a) | |||
| { | |||
| return new Maybe<TA, TB>(a); | |||
| } | |||
| public static implicit operator Maybe<TA, TB>(TB b) | |||
| { | |||
| return new Maybe<TA, TB>(b); | |||
| } | |||
| } | |||
| internal class SingleDeviceSaver | |||
| { | |||
| private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict; | |||
| public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict) | |||
| private IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> _tensor_slice_dict; | |||
| public SingleDeviceSaver(IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_slice_dict) | |||
| { | |||
| _tensor_slice_dict = tensor_slice_dict; | |||
| } | |||
| @@ -122,15 +30,15 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| _tensor_slice_dict = tensor_slice_dict.ToDictionary( | |||
| x => x.Key, x => x.Value.ToDictionary( | |||
| y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||
| as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||
| y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT0(y.Value)) | |||
| as IDictionary<string, OneOf<Tensor, SaveSpec>>); | |||
| } | |||
| public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict) | |||
| { | |||
| _tensor_slice_dict = tensor_slice_dict.ToDictionary( | |||
| x => x.Key, x => x.Value.ToDictionary( | |||
| y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||
| as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||
| y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT1(y.Value)) | |||
| as IDictionary<string, OneOf<Tensor, SaveSpec>>); | |||
| } | |||
| public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) | |||
| { | |||
| @@ -149,7 +57,7 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| var slice_spec = slice.Key; | |||
| var maybe_tensor = slice.Value; | |||
| if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||
| if(maybe_tensor.TryPickT1(out var spec, out var tensor)) | |||
| { | |||
| var tensor_value = spec.tensor; | |||
| if (tensor_value is not null) | |||
| @@ -161,7 +69,6 @@ namespace Tensorflow.Checkpoint | |||
| } | |||
| else | |||
| { | |||
| var tensor = maybe_tensor.GetValue<Tensor>(); | |||
| tensor_names.Add(checkpoint_key); | |||
| tensors.Add(tensor); | |||
| slice_specs.Add(slice_spec); | |||
| @@ -193,7 +100,7 @@ namespace Tensorflow.Checkpoint | |||
| var slice_spec = slice.Key; | |||
| var maybe_tensor = slice.Value; | |||
| // TODO: deal with other types. Currently only `SaveSpec` is allowed. | |||
| if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||
| if(maybe_tensor.TryPickT1(out var spec, out var tensor)) | |||
| { | |||
| tensor_dtypes.Add(spec.dtype); | |||
| slice_specs.Add(spec.slice_spec); | |||
| @@ -201,7 +108,6 @@ namespace Tensorflow.Checkpoint | |||
| } | |||
| else | |||
| { | |||
| var tensor = maybe_tensor.GetValue<Tensor>(); | |||
| tensor_dtypes.Add(tensor.dtype); | |||
| slice_specs.Add(slice_spec); | |||
| tensor_names.Add(checkpoint_key); | |||
| @@ -254,7 +160,7 @@ namespace Tensorflow.Checkpoint | |||
| /// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | |||
| /// <param name="registered_savers"></param> | |||
| /// <param name="call_with_mapped_capture"></param> | |||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||
| IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | |||
| { | |||
| _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | |||
| @@ -274,9 +180,9 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| restore_fn = new RestoreFunc(x => | |||
| { | |||
| if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) | |||
| if(x is IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) | |||
| { | |||
| return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>); | |||
| return obj._restore_from_tensors(x as IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>); | |||
| } | |||
| throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}."); | |||
| }); | |||
| @@ -286,14 +192,14 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| var checkpoint_key = item.Key; | |||
| IDictionary<string, Tensor> spec_to_tensor; | |||
| if(item.Value.TryGet<Tensor>(out var t)) | |||
| if(item.Value.TryPickT0(out var t, out var dic)) | |||
| { | |||
| spec_to_tensor = new Dictionary<string, Tensor>(); | |||
| spec_to_tensor[""] = t; | |||
| } | |||
| else | |||
| { | |||
| spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>(); | |||
| spec_to_tensor = dic; | |||
| } | |||
| foreach(var spec in spec_to_tensor) | |||
| @@ -399,7 +305,7 @@ namespace Tensorflow.Checkpoint | |||
| IDictionary<string, Operation> restore_func() | |||
| { | |||
| Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||
| Dictionary<RestoreFunc, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||
| Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | |||
| Dictionary<string, Operation> restore_ops = new(); | |||
| @@ -419,29 +325,29 @@ namespace Tensorflow.Checkpoint | |||
| var slice_spec = item.Key; | |||
| var tensor = item.Value; | |||
| var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>()); | |||
| if (!string.IsNullOrEmpty(slice_spec)) | |||
| { | |||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||
| { | |||
| Dictionary<string, Tensor> dict = new(); | |||
| dict[slice_spec] = tensor; | |||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict); | |||
| } | |||
| else | |||
| { | |||
| internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||
| internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor); | |||
| } | |||
| restore_fn_input_count[restore_fn]--; | |||
| if (restore_fn_input_count[restore_fn] == 0) | |||
| { | |||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
| foreach(var input in restore_fn_inputs[restore_fn]) | |||
| { | |||
| restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||
| @@ -519,7 +425,7 @@ namespace Tensorflow.Checkpoint | |||
| public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | |||
| { | |||
| Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
| Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
| foreach (var saveable in saveables) | |||
| { | |||
| var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using OneOf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| @@ -61,13 +62,13 @@ public class CheckpointPosition | |||
| } | |||
| } | |||
| public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||
| public (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||
| { | |||
| // skip the registered_saver | |||
| if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) | |||
| { | |||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||
| return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(), | |||
| new List<CheckpointPosition>(), null); | |||
| } | |||
| @@ -75,7 +76,7 @@ public class CheckpointPosition | |||
| List<Operation> existing_restore_ops; | |||
| List<CheckpointPosition> positions = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables; | |||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables; | |||
| if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | |||
| { | |||
| (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); | |||
| @@ -109,8 +110,8 @@ public class CheckpointPosition | |||
| /// Creates a saveable using the _serialize_to_tensor method. | |||
| /// </summary> | |||
| /// <param name="saveable_factories"></param> | |||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||
| IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||
| IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| { | |||
| string suffix = SaveableCompat.get_saveable_name(this.Trackable); | |||
| suffix = suffix ?? ""; | |||
| @@ -124,23 +125,23 @@ public class CheckpointPosition | |||
| var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); | |||
| // skip the cache. | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new(); | |||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> dict = new(); | |||
| dict[saveable_name] = saveable; | |||
| return (new List<Operation>(), dict); | |||
| } | |||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||
| IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||
| IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| if(ObjectProto.Attributes is null) | |||
| { | |||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>()); | |||
| return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>()); | |||
| } | |||
| List<Operation> existing_restore_ops = new(); | |||
| HashSet<string> created_compat_names = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||
| foreach (var serialized_tensor in ObjectProto.Attributes) | |||
| { | |||
| Operation existing_op; | |||
| @@ -172,12 +173,12 @@ public class CheckpointPosition | |||
| _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); | |||
| continue; | |||
| } | |||
| named_saveables[serialized_tensor.CheckpointKey] = saveable; | |||
| named_saveables[serialized_tensor.CheckpointKey] = saveable.Value; | |||
| } | |||
| return (existing_restore_ops, named_saveables); | |||
| } | |||
| private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||
| private OneOf<BaseResourceVariable, MySaveableObject>? _get_saveable_from_factory(IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||
| TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) | |||
| { | |||
| var expected_factory_name = serialized_tensor.Name; | |||
| @@ -221,7 +222,7 @@ public class CheckpointPosition | |||
| Queue<(CheckpointPosition, Trackable)> visit_queue = new(); | |||
| visit_queue.Enqueue((this, this.Trackable)); | |||
| List<Operation> restore_ops = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||
| List<CheckpointPosition> positions = new(); | |||
| CheckpointPosition current_position = null; | |||
| @@ -306,7 +307,7 @@ public class CheckpointPosition | |||
| } | |||
| } | |||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||
| private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||
| { | |||
| var trackable = this.Trackable; | |||
| trackable._maybe_initialize_trackable(); | |||
| @@ -318,7 +319,7 @@ public class CheckpointPosition | |||
| } | |||
| else | |||
| { | |||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||
| return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(), | |||
| new List<CheckpointPosition>(), null); | |||
| } | |||
| } | |||
| @@ -0,0 +1,13 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| public class TangentInfo | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| public object Indices { get; set; } | |||
| public object Tangents { get; set; } | |||
| } | |||
| } | |||
| @@ -1,6 +1,8 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Train; | |||
| @@ -17,11 +19,13 @@ namespace Tensorflow.Functions | |||
| internal FuncGraph func_graph; | |||
| protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; | |||
| protected Dictionary<string, string> _attrs; | |||
| protected FunctionSpec _function_spec; | |||
| protected FunctionSpec _pre_initialized_function_spec = null; | |||
| internal ForwardBackwardCall forward_backward; | |||
| public Tensor[] Inputs => func_graph.Inputs; | |||
| public Tensor[] CapturedInputs => func_graph.external_captures; | |||
| public string Name => _delayed_rewrite_functions.forward().Name; | |||
| public string Name => _delayed_rewrite_functions.Forward().Name; | |||
| public Tensor[] Outputs; | |||
| public Type ReturnType; | |||
| @@ -175,7 +179,13 @@ namespace Tensorflow.Functions | |||
| var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||
| Tensors flat_outputs = null; | |||
| if (executing_eagerly) | |||
| { | |||
| flat_outputs = forward_function.Call(args_with_tangents); | |||
| } | |||
| else | |||
| { | |||
| flat_outputs = forward_function.Call(args_with_tangents); | |||
| } | |||
| forward_backward.Record(flat_outputs); | |||
| return flat_outputs; | |||
| } | |||
| @@ -186,7 +196,7 @@ namespace Tensorflow.Functions | |||
| { | |||
| g = ops.get_default_graph(); | |||
| } | |||
| _delayed_rewrite_functions.forward().AddToGraph(g); | |||
| _delayed_rewrite_functions.Forward().AddToGraph(g); | |||
| } | |||
| public void SetExternalCaptures(IEnumerable<Tensor> captures) | |||
| @@ -196,8 +206,60 @@ namespace Tensorflow.Functions | |||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
| { | |||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
| return new ForwardBackwardCall(functions, args, tape_watching: true); | |||
| TangentInfo input_tangents; | |||
| if (executing_eagerly) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| else | |||
| { | |||
| input_tangents = new TangentInfo(); | |||
| } | |||
| if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER) | |||
| { | |||
| if(input_tangents.Indices is not null || executing_eagerly) | |||
| { | |||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
| return new ForwardBackwardCall(functions, args, tape_watching: true); | |||
| } | |||
| else | |||
| { | |||
| return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: true); | |||
| } | |||
| } | |||
| else if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| // TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall. | |||
| return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); | |||
| } | |||
| internal void _set_function_spec(FunctionSpec spec) | |||
| { | |||
| _function_spec = null; | |||
| _pre_initialized_function_spec = spec; | |||
| _initialize_function_spec(); | |||
| } | |||
| internal void _initialize_function_spec() | |||
| { | |||
| if(_pre_initialized_function_spec is null) | |||
| { | |||
| return; | |||
| } | |||
| Debug.Assert(_function_spec is null, "already initialized"); | |||
| var spec = _pre_initialized_function_spec; | |||
| //var args = spec.Fullargspec.DictValue.Fields["args"]; | |||
| // TODO(Rinne): self.structured_input_signature | |||
| _function_spec = new FunctionSpec() | |||
| { | |||
| Fullargspec = spec.Fullargspec, | |||
| IsMethod = spec.IsMethod, | |||
| InputSignature = spec.InputSignature | |||
| }; | |||
| } | |||
| public override string ToString() | |||
| @@ -5,6 +5,8 @@ using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Functions | |||
| @@ -14,7 +16,10 @@ namespace Tensorflow.Functions | |||
| public int _num_outputs; | |||
| FuncGraph _func_graph; | |||
| FunctionDef _definition; | |||
| Tensor[] _func_graph_outputs; | |||
| public string Name => _func_graph.FuncName; | |||
| public DataType[] OutputTypes { get; protected set; } | |||
| public Shape[] OutputShapes { get; protected set; } | |||
| public FunctionDef Definition | |||
| { | |||
| get | |||
| @@ -36,27 +41,69 @@ namespace Tensorflow.Functions | |||
| var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) | |||
| .Select(x => x as Operation).ToArray(); | |||
| var output_names = new string[0]; | |||
| OutputShapes = outputs.Select(x => x.shape).ToArray(); | |||
| OutputTypes = outputs.Select(x => x.dtype.as_datatype_enum()).ToArray(); | |||
| _func_graph = new FuncGraph(graph, name, attrs); | |||
| _func_graph_outputs = new List<Tensor>(outputs).ToArray(); | |||
| _func_graph.ToGraph(operations, inputs, outputs, output_names); | |||
| } | |||
| public Tensors Call(Tensors args) | |||
| { | |||
| // TODO(Rinne): Add arg `CancellationManager`. | |||
| // TODO(Rinne): Check the arg length. | |||
| var function_call_options = tf.Context.FunctionCallOptions; | |||
| string config; | |||
| if (string.IsNullOrEmpty(function_call_options.config_proto_serialized())) | |||
| { | |||
| config = function_utils.get_disabled_rewriter_config(); | |||
| } | |||
| else | |||
| { | |||
| config = function_call_options.config_proto_serialized(); | |||
| } | |||
| // TODO(Rinne): executor_type | |||
| var executing_eagerly = tf.Context.executing_eagerly(); | |||
| var attrs = new object[] | |||
| { | |||
| "executor_type", "", | |||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
| }; | |||
| var results = tf.Runner.TFE_Execute(tf.Context, | |||
| Tensor[] outputs; | |||
| if (executing_eagerly) | |||
| { | |||
| outputs = tf.Runner.TFE_Execute(tf.Context, | |||
| tf.Context.DeviceName, | |||
| _func_graph.FuncName, | |||
| args, | |||
| attrs, | |||
| _num_outputs); | |||
| return results; | |||
| } | |||
| else | |||
| { | |||
| tf.GradientTape().stop_recording(); | |||
| outputs = functional_ops.partitioned_call(args, this, OutputTypes, | |||
| executing_eagerly, config, ""); | |||
| } | |||
| foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs)) | |||
| { | |||
| handle_data_util.copy_handle_data(func_graph_output, outputs[i]); | |||
| } | |||
| if (executing_eagerly) | |||
| { | |||
| return outputs; | |||
| } | |||
| else | |||
| { | |||
| foreach(var (i, shape) in enumerate(OutputShapes)) | |||
| { | |||
| outputs[i].shape = shape; | |||
| } | |||
| return outputs; | |||
| } | |||
| } | |||
| public void AddToGraph(Graph g = null) | |||
| @@ -9,16 +9,46 @@ namespace Tensorflow | |||
| #pragma warning disable CS0169 // The field 'Function._handle' is never used | |||
| private IntPtr _handle; | |||
| #pragma warning restore CS0169 // The field 'Function._handle' is never used | |||
| protected Func<Tensors, Tensors> _function; | |||
| protected ConcreteFunction _concrete_variable_creation_fn; | |||
| protected bool _auto_graph; | |||
| public string Name { get; set; } | |||
| public Function() | |||
| public Function(Func<Tensors, Tensors> function, | |||
| string name, bool auto_graph = true) | |||
| { | |||
| _function = function; | |||
| Name = name; | |||
| _auto_graph = auto_graph; | |||
| } | |||
| public virtual Tensors Apply(Tensors inputs) | |||
| { | |||
| if (_run_functions_eagerly()) | |||
| { | |||
| return _function(inputs); | |||
| } | |||
| var result = _call(inputs); | |||
| return result; | |||
| } | |||
| public Function(string name) | |||
| protected virtual Tensors _call(Tensors inputs) | |||
| { | |||
| Name = name; | |||
| _initialize(); | |||
| return _concrete_variable_creation_fn.CallFlat(inputs, | |||
| _concrete_variable_creation_fn.CapturedInputs); | |||
| } | |||
| protected virtual bool _run_functions_eagerly() | |||
| { | |||
| return false; | |||
| } | |||
| private void _initialize() | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -15,11 +15,11 @@ namespace Tensorflow.Functions | |||
| /// </summary> | |||
| public abstract class TapeGradientFunctions | |||
| { | |||
| string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||
| string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||
| string _FORWARD_PREFIX = "__forward_"; | |||
| string _BACKWARD_PREFIX = "__backward_"; | |||
| string _INFERENCE_PREFIX = "__inference_"; | |||
| protected string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||
| protected string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||
| protected string _FORWARD_PREFIX = "__forward_"; | |||
| protected string _BACKWARD_PREFIX = "__backward_"; | |||
| protected string _INFERENCE_PREFIX = "__inference_"; | |||
| protected FuncGraph _func_graph; | |||
| protected EagerDefinedFunction _forward; | |||
| @@ -35,8 +35,9 @@ namespace Tensorflow.Functions | |||
| _func_graph = func_graph; | |||
| } | |||
| public EagerDefinedFunction Forward(Tensors inference_args) | |||
| public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null) | |||
| { | |||
| // TODO(Rinne): add input_tangents arg. | |||
| return ForwardAndBackwardFunctions(inference_args); | |||
| } | |||
| @@ -45,8 +46,9 @@ namespace Tensorflow.Functions | |||
| /// </summary> | |||
| /// <param name="flat_outputs"></param> | |||
| /// <param name="inference_args"></param> | |||
| public void Record(Tensors flat_outputs, Tensors inference_args) | |||
| public virtual void Record(Tensors flat_outputs, Tensors inference_args) | |||
| { | |||
| // TODO(Rinne): add arg `input_tagents`. | |||
| var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | |||
| tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, | |||
| getBackwardFunction: backward_function); | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Variables; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Functions | |||
| @@ -5,16 +5,13 @@ using Tensorflow.Graphs; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| public class DelayedRewriteGradientFunctions | |||
| public class DelayedRewriteGradientFunctions: TapeGradientFunctions | |||
| { | |||
| static readonly string _INFERENCE_PREFIX = "__inference_"; | |||
| static readonly string _BACKWARD_PREFIX = "__backward_"; | |||
| static readonly string _FORWARD_PREFIX = "__forward_"; | |||
| FuncGraph _func_graph; | |||
| EagerDefinedFunction _inference_function; | |||
| Dictionary<string, string> _attrs; | |||
| int _num_inference_outputs; | |||
| public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, string> attrs) | |||
| :base(func_graph, false) | |||
| { | |||
| _func_graph= func_graph; | |||
| _inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name), | |||
| @@ -23,7 +20,7 @@ namespace Tensorflow.Functions | |||
| _num_inference_outputs = _func_graph.Outputs.Length; | |||
| } | |||
| public EagerDefinedFunction forward(Tensors inference_args = null, Tensors input_tangents = null) | |||
| public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null) | |||
| { | |||
| if(input_tangents is not null) | |||
| { | |||
| @@ -33,7 +30,23 @@ namespace Tensorflow.Functions | |||
| return _inference_function; | |||
| } | |||
| private static string _inference_name(string name) | |||
| public override void Record(Tensors flat_outputs, Tensors inference_args) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| throw new NotImplementedException(); | |||
| base.Record(flat_outputs, inference_args); | |||
| } | |||
| //private (BackwardFunction, Tensors) _backward(Tensors outputs) | |||
| //{ | |||
| // Tensor[] backward_function(Tensor[] grads, long[] unneeded_gradients) | |||
| // { | |||
| // var call_op = outputs[0].op; | |||
| // } | |||
| //} | |||
| private string _inference_name(string name) | |||
| { | |||
| return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; | |||
| } | |||
| @@ -25,6 +25,11 @@ namespace Tensorflow | |||
| { | |||
| public class gradients_util | |||
| { | |||
| // Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are | |||
| // unfortunately too slow to use here. | |||
| public static int POSSIBLE_GRADIENT_TYPES_NONE = 0; | |||
| public static int POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1; | |||
| public static int POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2; | |||
| public static Tensor[] _GradientsHelper(Tensor[] ys, | |||
| Tensor[] xs, | |||
| Tensor[] grad_ys = null, | |||
| @@ -129,6 +129,7 @@ namespace Tensorflow | |||
| protected Graph outer_graph; | |||
| public Graph OuterGraph => outer_graph; | |||
| public Dictionary<string, EagerDefinedFunction> Functions => _functions; | |||
| public SafeGraphHandle c_graph => _handle; | |||
| public Graph() | |||
| { | |||
| @@ -208,5 +208,9 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data); | |||
| } | |||
| } | |||
| @@ -14,10 +14,14 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using Google.Protobuf.WellKnownTypes; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| @@ -25,6 +29,72 @@ namespace Tensorflow | |||
| { | |||
| public class functional_ops | |||
| { | |||
| public static Tensor[] partitioned_call(Tensors args, EagerDefinedFunction f, DataType[] tout, | |||
| bool executing_eagerly, string config, string executor_type) | |||
| { | |||
| if (tout is null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| if (config is null) | |||
| { | |||
| config = function_utils.get_disabled_rewriter_config(); | |||
| } | |||
| if (executor_type is null) | |||
| { | |||
| executor_type = ""; | |||
| } | |||
| if (executing_eagerly) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| var converted_args = args.Select(x => ops.convert_to_tensor(x)).ToArray(); | |||
| AttrValue tin_attr = new() | |||
| { | |||
| List = new AttrValue.Types.ListValue() | |||
| }; | |||
| tin_attr.List.Type.AddRange(args.Select(x => x.dtype.as_datatype_enum())); | |||
| AttrValue tout_attr = new() | |||
| { | |||
| List = new AttrValue.Types.ListValue() | |||
| }; | |||
| tout_attr.List.Type.AddRange(tout); | |||
| AttrValue func_attr = new() | |||
| { | |||
| Func = new NameAttrList() | |||
| }; | |||
| func_attr.Func.Name = f.Name; | |||
| AttrValue executor_type_attr = new AttrValue() | |||
| { | |||
| S = tf.compat.as_bytes(executor_type) | |||
| }; | |||
| AttrValue config_proto = new AttrValue() | |||
| { | |||
| S = ByteString.CopyFromUtf8(executor_type) | |||
| }; | |||
| var graph = ops.get_default_graph(); | |||
| f.AddToGraph(graph); | |||
| // TODO(Rinne): complete it with `f.stateful` | |||
| var op_name = "PartitionedCall"; | |||
| string xla_compile_attr = "_XlaMustCompile"; | |||
| Dictionary<string, AttrValue> op_attrs = new(); | |||
| op_attrs["Tin"] = tin_attr; | |||
| op_attrs["Tout"] = tout_attr; | |||
| op_attrs["f"] = func_attr; | |||
| op_attrs["config_proto"] = config_proto; | |||
| op_attrs["executor_type"] = executor_type_attr; | |||
| // TODO(Rinne): deal with `f.definition`. | |||
| var op = graph.create_op(op_name, args, tout.Select(x => x.as_tf_dtype()).ToArray(), | |||
| name: op_name, attrs: op_attrs); | |||
| var outputs = op.outputs; | |||
| // TODO(Rinne): deal with `f.graph`. | |||
| return outputs; | |||
| } | |||
| public static Tensor scan( | |||
| Func<Tensor, Tensor, Tensor> fn, | |||
| Tensor elems, | |||
| @@ -0,0 +1,83 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using System.Xml.Linq; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Functions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| public class gen_functional_ops | |||
| { | |||
| public static Tensor[] partitioned_call(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, | |||
| string config = "", string config_proto = "", string executor_type = "", string name = null) | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| try | |||
| { | |||
| return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("PartitionedCall", name, | |||
| args, tout, f, config, config_proto, executor_type)); | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| } | |||
| if (config is null) | |||
| { | |||
| config = ""; | |||
| } | |||
| if (config_proto is null) | |||
| { | |||
| config_proto = ""; | |||
| } | |||
| if (executor_type is null) | |||
| { | |||
| executor_type = ""; | |||
| } | |||
| Dictionary<string, object> kwargs = new(); | |||
| kwargs["args"] = args; | |||
| kwargs["Tout"] = tout; | |||
| kwargs["f"] = f; | |||
| kwargs["config"] = config; | |||
| kwargs["config_proto"] = config_proto; | |||
| kwargs["executor_type"] = executor_type; | |||
| var output = tf.OpDefLib._apply_op_helper("PartitionedCall", | |||
| name, kwargs); | |||
| var result = output.outputs; | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return result; | |||
| } | |||
| public static Tensor[] partitioned_call_eager_fallback(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, | |||
| string config, string config_proto, string executor_type, string name, Context ctx) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| throw new NotImplementedException(); | |||
| if(config is null) | |||
| { | |||
| config = ""; | |||
| } | |||
| if(config_proto is null) | |||
| { | |||
| config_proto = ""; | |||
| } | |||
| if(executor_type is null) | |||
| { | |||
| executor_type = ""; | |||
| } | |||
| object[] attrs = new object[] | |||
| { | |||
| }; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,7 +1,9 @@ | |||
| using System; | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| @@ -11,18 +13,31 @@ namespace Tensorflow.Operations | |||
| { | |||
| if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) | |||
| { | |||
| SafeTensorHandle handle_data; | |||
| HandleData handle_data; | |||
| if(source_t is EagerTensor) | |||
| { | |||
| handle_data = source_t.Handle; | |||
| handle_data = source_t.HandleData; | |||
| } | |||
| else | |||
| { | |||
| handle_data = ops.get_resource_handle_data(source_t); | |||
| } | |||
| throw new NotImplementedException(); | |||
| //if(handle_data is not null && handle_data.) | |||
| if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null | |||
| && handle_data.ShapeAndType.Count > 0) | |||
| { | |||
| set_handle_data(target_t, handle_data); | |||
| } | |||
| } | |||
| } | |||
| public static void set_handle_data(Tensor target_t, HandleData handle_data) | |||
| { | |||
| if(target_t is EagerTensor) | |||
| { | |||
| target_t.HandleData = handle_data; | |||
| return; | |||
| } | |||
| c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); | |||
| } | |||
| } | |||
| } | |||
| @@ -39,7 +39,7 @@ namespace Tensorflow | |||
| public static bool is_resource_variable(IVariableV1 var) | |||
| { | |||
| return var is ResourceVariable; | |||
| return var is BaseResourceVariable; | |||
| } | |||
| public static bool is_resource_variable(Trackable var) | |||
| @@ -231,5 +231,21 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| } | |||
| public static void _maybe_set_handle_data(TF_DataType dtype, Tensor handle, Tensor tensor) | |||
| { | |||
| if(dtype == dtypes.variant) | |||
| { | |||
| var handle_data = get_eager_safe_handle_data(handle); | |||
| if(handle_data.IsSet && handle_data.ShapeAndType.Count > 1) | |||
| { | |||
| tensor.HandleData = new HandleData() | |||
| { | |||
| IsSet = true | |||
| }; | |||
| tensor.HandleData.ShapeAndType.AddRange(handle_data.ShapeAndType.Skip(1)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -479,7 +479,7 @@ namespace Tensorflow { | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType> ShapeAndType { | |||
| get { return shapeAndType_; } | |||
| get { return shapeAndType_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| @@ -277,15 +277,15 @@ namespace Tensorflow { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public SavedObject() { | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public SavedObject() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public SavedObject(SavedObject other) : this() { | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public SavedObject(SavedObject other) : this() { | |||
| children_ = other.children_.Clone(); | |||
| dependencies_ = other.dependencies_.Clone(); | |||
| slotVariables_ = other.slotVariables_.Clone(); | |||
| @@ -329,7 +329,9 @@ namespace Tensorflow { | |||
| public const int ChildrenFieldNumber = 1; | |||
| private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | |||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||
| private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_dependencies_codec | |||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||
| /// <summary> | |||
| /// Objects which this object depends on: named edges in the dependency | |||
| @@ -501,7 +503,8 @@ namespace Tensorflow { | |||
| return true; | |||
| } | |||
| if(!children_.Equals(other.children_)) return false; | |||
| if(!slotVariables_.Equals(other.slotVariables_)) return false; | |||
| if (!dependencies_.Equals(other.dependencies_)) return false; | |||
| if (!slotVariables_.Equals(other.slotVariables_)) return false; | |||
| if (!object.Equals(UserObject, other.UserObject)) return false; | |||
| if (!object.Equals(Asset, other.Asset)) return false; | |||
| if (!object.Equals(Function, other.Function)) return false; | |||
| @@ -519,6 +522,7 @@ namespace Tensorflow { | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| hash ^= children_.GetHashCode(); | |||
| hash ^= dependencies_.GetHashCode(); | |||
| hash ^= slotVariables_.GetHashCode(); | |||
| if (kindCase_ == KindOneofCase.UserObject) hash ^= UserObject.GetHashCode(); | |||
| if (kindCase_ == KindOneofCase.Asset) hash ^= Asset.GetHashCode(); | |||
| @@ -544,6 +548,7 @@ namespace Tensorflow { | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| children_.WriteTo(output, _repeated_children_codec); | |||
| children_.WriteTo(output, _repeated_dependencies_codec); | |||
| slotVariables_.WriteTo(output, _repeated_slotVariables_codec); | |||
| if (kindCase_ == KindOneofCase.UserObject) { | |||
| output.WriteRawTag(34); | |||
| @@ -587,6 +592,7 @@ namespace Tensorflow { | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| size += children_.CalculateSize(_repeated_children_codec); | |||
| size += children_.CalculateSize(_repeated_dependencies_codec); | |||
| size += slotVariables_.CalculateSize(_repeated_slotVariables_codec); | |||
| if (kindCase_ == KindOneofCase.UserObject) { | |||
| size += 1 + pb::CodedOutputStream.ComputeMessageSize(UserObject); | |||
| @@ -619,7 +625,7 @@ namespace Tensorflow { | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| //[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(SavedObject other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -682,7 +688,7 @@ namespace Tensorflow { | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| //[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| @@ -692,9 +698,10 @@ namespace Tensorflow { | |||
| break; | |||
| case 10: { | |||
| children_.AddEntriesFrom(input, _repeated_children_codec); | |||
| dependencies_.AddRange(children_.Except(dependencies_)); | |||
| break; | |||
| } | |||
| case 26: { | |||
| case 26: { | |||
| slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); | |||
| break; | |||
| } | |||
| @@ -109,7 +109,12 @@ https://tensorflownet.readthedocs.io</Description> | |||
| <ItemGroup> | |||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | |||
| <PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> | |||
| <PackageReference Include="OneOf" Version="3.0.223" /> | |||
| <PackageReference Include="Protobuf.Text" Version="0.6.1" /> | |||
| <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\Tensorflow.Common\Tensorflow.Common.csproj" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -87,6 +87,7 @@ namespace Tensorflow | |||
| public object Tag { get; set; } | |||
| protected new SafeTensorHandle _handle; | |||
| public virtual SafeTensorHandle Handle => _handle; | |||
| public Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData { get; internal set; } | |||
| protected SafeEagerTensorHandle _eagerTensorHandle; | |||
| /// <summary> | |||
| @@ -14,18 +14,19 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using OneOf; | |||
| using Tensorflow.Checkpoint; | |||
| namespace Tensorflow | |||
| { | |||
| public class MySaveableObject | |||
| { | |||
| protected Maybe<Tensor, BaseResourceVariable> _op; | |||
| protected OneOf<Tensor, BaseResourceVariable> _op; | |||
| public Tensor op | |||
| { | |||
| get | |||
| { | |||
| if(_op.TryGet<Tensor>(out var tensor)) | |||
| if(_op.TryPickT0(out var tensor, out var _)) | |||
| { | |||
| return tensor; | |||
| } | |||
| @@ -43,7 +44,7 @@ namespace Tensorflow | |||
| { | |||
| get | |||
| { | |||
| if (_op.TryGet<BaseResourceVariable>(out var v)) | |||
| if (_op.TryPickT1(out var v, out var _)) | |||
| { | |||
| return v; | |||
| } | |||
| @@ -25,11 +25,32 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| /// <param name="saved_concrete_function"></param> | |||
| /// <param name="concrete_functions"></param> | |||
| /// <returns></returns> | |||
| public static ConcreteFunction recreate_function(SavedFunction saved_concrete_function, | |||
| public static Function recreate_function(SavedFunction saved_function, | |||
| IDictionary<string, ConcreteFunction> concrete_functions) | |||
| { | |||
| var function_spec = _deserialize_function_spec_as_nonmethod(saved_concrete_function.FunctionSpec); | |||
| return null; | |||
| var function_spec = _deserialize_function_spec_as_nonmethod(saved_function.FunctionSpec); | |||
| List<ConcreteFunction> concrete_function_objects = new(); | |||
| foreach(var concrete_function_name in saved_function.ConcreteFunctions) | |||
| { | |||
| concrete_function_objects.Add(concrete_functions[concrete_function_name]); | |||
| } | |||
| foreach(var cf in concrete_function_objects) | |||
| { | |||
| cf._set_function_spec(function_spec); | |||
| } | |||
| foreach(var function_name in saved_function.ConcreteFunctions) | |||
| { | |||
| var function = concrete_functions[function_name]; | |||
| if(_concrete_function_callable_with(function, null, false)) | |||
| { | |||
| return new RestoredFunction(null, function, "function_from_deserialization"); | |||
| } | |||
| } | |||
| return new RestoredFunction(x => x, new ConcreteFunction(x => x, TF_DataType.TF_FLOAT), "function_return_itself"); | |||
| //throw new ValueError("Unexpected runtime behavior, please submit an issue to " + | |||
| // "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| public static Dictionary<string, ConcreteFunction> load_function_def_library(FunctionDefLibrary library, | |||
| @@ -385,5 +406,31 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| JitCompile = function_spec_proto.JitCompile | |||
| }; | |||
| } | |||
| private static Tensors _call_concrete_function(ConcreteFunction function, Tensors inputs) | |||
| { | |||
| // TODO(Rinne): var expected_structure = function.func_graph.structured_input_signature | |||
| return function.CallFlat(inputs, function.CapturedInputs); | |||
| } | |||
| private static bool _concrete_function_callable_with(ConcreteFunction function, Tensors inputs, bool allow_conversion) | |||
| { | |||
| // TODO(Rinne): revise it. | |||
| return true; | |||
| } | |||
| } | |||
| public class RestoredFunction : Function | |||
| { | |||
| public RestoredFunction(Func<Tensors, Tensors> function, ConcreteFunction concrete_function, | |||
| string name, bool auto_graph = true): base(function, name, auto_graph) | |||
| { | |||
| _concrete_variable_creation_fn = concrete_function; | |||
| } | |||
| protected override bool _run_functions_eagerly() | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| @@ -14,6 +14,7 @@ using Tensorflow.Variables; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using Tensorflow.Trackables; | |||
| using OneOf; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -44,6 +45,8 @@ namespace Tensorflow | |||
| _asset_file_def = meta_graph.AssetFileDef; | |||
| _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); | |||
| _proto = object_graph_proto; | |||
| // Debug(Rinne) | |||
| var temp = _proto.ToString(); | |||
| _export_dir = export_dir; | |||
| // TODO: `this._concrete_functions` and `this._restored_concrete_functions` | |||
| _concrete_functions = function_deserialization.load_function_def_library( | |||
| @@ -259,9 +262,9 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="proto"></param> | |||
| /// <returns></returns> | |||
| private Dictionary<Maybe<string, int>, int> _get_node_dependencies(SavedObject proto) | |||
| private Dictionary<OneOf<string, int>, int> _get_node_dependencies(SavedObject proto) | |||
| { | |||
| Dictionary<Maybe<string, int>, int> dependencies = new(); | |||
| Dictionary<OneOf<string, int>, int> dependencies = new(); | |||
| foreach(var refer in proto.Dependencies) | |||
| { | |||
| dependencies[refer.LocalName] = refer.NodeId; | |||
| @@ -375,11 +378,6 @@ namespace Tensorflow | |||
| // Re-create everything. | |||
| foreach (var (node_id, proto) in _iter_all_nodes()) | |||
| { | |||
| if(node_id == 45) | |||
| { | |||
| // TODelete | |||
| Console.WriteLine(); | |||
| } | |||
| if (nodes.ContainsKey(node_id)) | |||
| { | |||
| continue; | |||
| @@ -474,7 +472,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private void _setup_function_captures(string concrete_function_name, IDictionary<Maybe<string, int>, Trackable> nodes) | |||
| private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, Trackable> nodes) | |||
| { | |||
| if (_restored_concrete_functions.Contains(concrete_function_name)) | |||
| { | |||
| @@ -509,6 +507,11 @@ namespace Tensorflow | |||
| /// <param name="node_id"></param> | |||
| private void _add_object_graph_edges(SavedObject proto, int node_id) | |||
| { | |||
| // Debug(Rinne) | |||
| if(node_id == 1) | |||
| { | |||
| Console.WriteLine(); | |||
| } | |||
| var obj = _nodes[node_id]; | |||
| var setter = _node_setters[node_id]; | |||
| @@ -549,8 +552,13 @@ namespace Tensorflow | |||
| private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes) | |||
| { | |||
| // skip the registered classes. | |||
| if(node_id == 16) | |||
| { | |||
| // Debug(Rinne) | |||
| Console.WriteLine(); | |||
| } | |||
| Dictionary<Maybe<string, int>, Trackable> dependencies = new(); | |||
| Dictionary<OneOf<string, int>, Trackable> dependencies = new(); | |||
| foreach(var item in _get_node_dependencies(proto)) | |||
| { | |||
| dependencies[item.Key] = nodes[item.Value]; | |||
| @@ -571,7 +579,7 @@ namespace Tensorflow | |||
| /// <param name="proto"></param> | |||
| /// <param name="node_id"></param> | |||
| /// <param name="dependencies"></param> | |||
| private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<Maybe<string, int>, Trackable> dependencies) | |||
| private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, Trackable> dependencies) | |||
| { | |||
| return proto.KindCase switch | |||
| { | |||
| @@ -637,10 +645,10 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto, | |||
| Dictionary<Maybe<string, int>, Trackable> dependencies) | |||
| private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto, | |||
| Dictionary<OneOf<string, int>, Trackable> dependencies) | |||
| { | |||
| var fn = function_deserialization.recreate_function(proto, null); | |||
| var fn = function_deserialization.recreate_function(proto, _concrete_functions); | |||
| foreach (var name in proto.ConcreteFunctions) | |||
| { | |||
| _setup_function_captures(name, dependencies); | |||
| @@ -649,7 +657,7 @@ namespace Tensorflow | |||
| } | |||
| private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | |||
| IDictionary<Maybe<string, int>, Trackable> dependencies) | |||
| IDictionary<OneOf<string, int>, Trackable> dependencies) | |||
| { | |||
| var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); | |||
| _setup_function_captures(proto.ConcreteFunctionName, dependencies); | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using OneOf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| @@ -174,7 +175,7 @@ namespace Tensorflow | |||
| full_name = name + "_" + attr; | |||
| } | |||
| var op = factory(full_name); | |||
| if(op.TryGet<BaseResourceVariable>(out var variable)) | |||
| if(op.TryPickT0(out var variable, out var saveable)) | |||
| { | |||
| foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||
| { | |||
| @@ -183,7 +184,6 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| var saveable = op.GetValue<MySaveableObject>(); | |||
| foreach (var v in saveable_objects_for_op(saveable, saveable.name)) | |||
| { | |||
| yield return v; | |||
| @@ -252,11 +252,11 @@ namespace Tensorflow | |||
| return names_to_saveables; | |||
| } | |||
| public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj) | |||
| public static IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj) | |||
| { | |||
| // skip the process of type `PythonState` | |||
| Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||
| OneOf<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||
| { | |||
| // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. | |||
| var tensor_dict = obj.serialize_to_tensors(); | |||
| @@ -272,14 +272,14 @@ namespace Tensorflow | |||
| string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | |||
| IDictionary<string, Tensor> internal_dict; | |||
| if (maybe_tensor.TryGet<Tensor>(out var tensor)) | |||
| if (maybe_tensor.TryPickT0(out var tensor, out var dic)) | |||
| { | |||
| internal_dict = new Dictionary<string, Tensor>(); | |||
| internal_dict[""] = tensor; | |||
| } | |||
| else | |||
| { | |||
| internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | |||
| internal_dict = dic; | |||
| } | |||
| foreach (var item in internal_dict) | |||
| @@ -292,7 +292,7 @@ namespace Tensorflow | |||
| if (trackable_has_serialize_to_tensor(obj)) | |||
| { | |||
| Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||
| Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> res = new(); | |||
| res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; | |||
| return res; | |||
| } | |||
| @@ -316,9 +316,9 @@ namespace Tensorflow | |||
| /// Converts a list of SaveableObjects to a tensor dictionary. | |||
| /// </summary> | |||
| /// <param name="saveables"></param> | |||
| public static Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||
| public static Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||
| { | |||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
| foreach (var saveable in saveables) | |||
| { | |||
| foreach (var spec in saveable.specs) | |||
| @@ -328,7 +328,7 @@ namespace Tensorflow | |||
| var slice_spec = convert_to_string(spec.slice_spec); | |||
| if (!string.IsNullOrEmpty(slice_spec)) | |||
| { | |||
| tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValue<IDictionary<string, Tensor>>()[slice_spec] = spec.tensor; | |||
| tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).AsT1[slice_spec] = spec.tensor; | |||
| } | |||
| else | |||
| { | |||
| @@ -343,7 +343,7 @@ namespace Tensorflow | |||
| /// Generates `Trackable._restore_from_tensors` from SaveableObjects. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public static Func<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>> saveable_object_to_restore_fn(IList<MySaveableObject> saveables) | |||
| public static Func<IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>> saveable_object_to_restore_fn(IList<MySaveableObject> saveables) | |||
| { | |||
| return (restored_tensors) => | |||
| { | |||
| @@ -359,14 +359,14 @@ namespace Tensorflow | |||
| var maybe_tensor = restored_tensors[name]; | |||
| IDictionary<string, Tensor> dict; | |||
| if(maybe_tensor.TryGet<Tensor>(out var tensor)) | |||
| if(maybe_tensor.TryPickT0(out var tensor, out var dic)) | |||
| { | |||
| dict = new Dictionary<string, Tensor>(); | |||
| dict[""] = tensor; | |||
| } | |||
| else | |||
| { | |||
| dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | |||
| dict = dic; | |||
| } | |||
| saveable_restored_tensors.Add(dict[slice_spec]); | |||
| } | |||
| @@ -381,18 +381,18 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="saveable_fn_by_name"></param> | |||
| /// <param name="temp_session"></param> | |||
| public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects( | |||
| public static IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects( | |||
| IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session) | |||
| { | |||
| if (saveable_fn_by_name.Count > 0) | |||
| { | |||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||
| var res = new Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>(); | |||
| return res; | |||
| } | |||
| public static Maybe<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||
| public static OneOf<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, OneOf<BaseResourceVariable, MySaveableObject>> factory, | |||
| bool call_with_mapped_captures = false) | |||
| { | |||
| return factory(key); | |||
| @@ -412,7 +412,7 @@ namespace Tensorflow | |||
| public object Obj => _obj; | |||
| public IList<MySaveableObject> mySaveables=> _saveables; | |||
| public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
| public override IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
| { | |||
| return saveable_object_util.saveable_object_to_tensor_dict(_saveables); | |||
| } | |||
| @@ -422,7 +422,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="restored_tensors"></param> | |||
| /// <returns></returns> | |||
| public override IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors) | |||
| public override IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors) | |||
| { | |||
| List<string> expected_keys = new(); | |||
| foreach(var saveable in _saveables) | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using OneOf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| @@ -43,8 +44,8 @@ namespace Tensorflow.Train | |||
| protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | |||
| protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies; | |||
| protected IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories = | |||
| new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||
| protected IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories = | |||
| new Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>(); | |||
| private bool _manual_tracking = true; | |||
| private static Trackable _none = new AutoTrackable(); | |||
| @@ -73,7 +74,7 @@ namespace Tensorflow.Train | |||
| public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | |||
| public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | |||
| public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies; | |||
| public IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories | |||
| public IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories | |||
| { | |||
| get | |||
| { | |||
| @@ -249,9 +250,9 @@ namespace Tensorflow.Train | |||
| return self_tensor_map.Keys.ToList(); | |||
| } | |||
| public virtual IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||
| public virtual IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||
| { | |||
| Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||
| OneOf<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||
| { | |||
| throw new NotImplementedException(); | |||
| //return new TrackableSaveable(this, null, name, null, null); | |||
| @@ -259,7 +260,7 @@ namespace Tensorflow.Train | |||
| if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | |||
| { | |||
| // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | |||
| Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||
| Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> res = new(); | |||
| res[""] = create_saveable; | |||
| return res; | |||
| } | |||
| @@ -278,12 +279,12 @@ namespace Tensorflow.Train | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| /// <exception cref="NotImplementedException"></exception> | |||
| public virtual IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
| public virtual IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public virtual IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors) | |||
| public virtual IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| @@ -0,0 +1,23 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Util | |||
| { | |||
| internal static class function_utils | |||
| { | |||
| private static string _rewriter_config_optimizer_disabled; | |||
| public static string get_disabled_rewriter_config() | |||
| { | |||
| if(_rewriter_config_optimizer_disabled is null) | |||
| { | |||
| var config = new ConfigProto(); | |||
| var rewriter_config = config.GraphOptions.RewriteOptions; | |||
| rewriter_config.DisableMetaOptimizer = true; | |||
| _rewriter_config_optimizer_disabled = config.ToString(); | |||
| } | |||
| return _rewriter_config_optimizer_disabled; | |||
| } | |||
| } | |||
| } | |||
| @@ -8,6 +8,7 @@ using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using OneOf; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -155,7 +156,7 @@ namespace Tensorflow | |||
| { | |||
| variable_accessed(this); | |||
| var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | |||
| // _maybe_set_handle_data(_dtype, _handle, result); | |||
| resource_variable_ops._maybe_set_handle_data(_dtype, handle, result); | |||
| // have to set shape when converting to substituent placeholder | |||
| if (result.shape.ndim == -1) | |||
| @@ -293,9 +294,9 @@ namespace Tensorflow | |||
| resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | |||
| } | |||
| public override IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||
| public override IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||
| { | |||
| var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||
| var res = new Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>(); | |||
| res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; | |||
| return res; | |||
| } | |||
| @@ -124,7 +124,9 @@ namespace Tensorflow | |||
| initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; | |||
| ops.colocate_with(initializer_op); | |||
| tf.device(handle.Device); | |||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||
| _graph_element = gen_array_ops.identity(handle, name = "read"); | |||
| ops.add_to_collections<IVariableV1>(collections, this); | |||
| _dtype = handle.dtype; | |||
| @@ -141,6 +143,12 @@ namespace Tensorflow | |||
| gen_resource_variable_ops.assign_variable_op(handle, _initial_value); | |||
| initializer_op = null; | |||
| _graph_element = null; | |||
| if (!string.IsNullOrEmpty(caching_device)) | |||
| { | |||
| tf.device(caching_device); | |||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||
| } | |||
| _dtype = _initial_value.dtype.as_base_dtype(); | |||
| // initial_value = _in_graph_mode ? initial_value : null; | |||
| } | |||
| @@ -9,7 +9,7 @@ namespace Tensorflow.Variables | |||
| /// <summary> | |||
| /// A variable with no initializer. | |||
| /// </summary> | |||
| public sealed class UninitializedVariable: BaseResourceVariable | |||
| public sealed class UninitializedVariable: BaseResourceVariable, IVariableV1 | |||
| { | |||
| // TODO: complete the arg list. | |||
| public UninitializedVariable( | |||
| @@ -23,6 +23,7 @@ namespace Tensorflow.Variables | |||
| { | |||
| string unique_id = ""; | |||
| string handle_name = ""; | |||
| Tensor created_handle = null; | |||
| tf_with(ops.init_scope(), (x) => | |||
| { | |||
| _in_graph_mode = !tf.Context.executing_eagerly(); | |||
| @@ -40,7 +41,7 @@ namespace Tensorflow.Variables | |||
| unique_id = $"{handle_name}-{ops.uid()}"; | |||
| shared_name = null; | |||
| } | |||
| var handle = resource_variable_ops.variable_handle_from_shape_and_dtype( | |||
| created_handle = resource_variable_ops.variable_handle_from_shape_and_dtype( | |||
| shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data); | |||
| // skip the assignment of `handle._parent_trackable` because of lack of API. | |||
| // skip the assignment of `handle._name` and `handle._unique_id` because of accessability. | |||
| @@ -51,7 +52,7 @@ namespace Tensorflow.Variables | |||
| { | |||
| tf.device(handle.Device); | |||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
| // _maybe_set_handle_data(dtype, handle, value) | |||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||
| _graph_element = value; | |||
| }); | |||
| ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); | |||
| @@ -64,7 +65,7 @@ namespace Tensorflow.Variables | |||
| }); | |||
| _shape = shape; | |||
| _dtype = dtype; | |||
| base.__init__(trainable, handle, unique_id: unique_id, handle_name: handle_name); | |||
| base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name); | |||
| } | |||
| } | |||
| } | |||
| @@ -26,6 +26,7 @@ using Tensorflow.Eager; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -572,9 +573,12 @@ namespace Tensorflow | |||
| return get_default_graph().building_function; | |||
| } | |||
| public static SafeTensorHandle get_resource_handle_data(Tensor graph_op) | |||
| public static HandleData get_resource_handle_data(Tensor graph_op) | |||
| { | |||
| throw new NotImplementedException(); | |||
| // This implementation hasn't been checked for some reasons. | |||
| // If it throws an exception in the future, please check it. | |||
| var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | |||
| return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); | |||
| } | |||
| public static void dismantle_graph(Graph graph) | |||
| @@ -40,7 +40,7 @@ namespace Tensorflow.Keras.Engine | |||
| /// <summary> | |||
| /// Arguments initialize layer. | |||
| /// </summary> | |||
| LayerArgs args; | |||
| internal LayerArgs args; | |||
| /// <summary> | |||
| /// Indicates whether `build` needs to be called upon layer call, to create | |||
| @@ -147,7 +147,7 @@ namespace Tensorflow.Keras.Engine | |||
| List<INode> outboundNodes; | |||
| public List<INode> OutboundNodes => outboundNodes; | |||
| public JObject SerializedAttributes { get; set; } | |||
| public Dictionary<string, object> SerializedAttributes { get; set; } | |||
| ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | |||
| public CallContext CallContext => callContext.Value; | |||
| @@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Saving | |||
| { | |||
| public class KerasObjectLoader | |||
| { | |||
| internal static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||
| internal static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES; | |||
| private SavedMetadata _metadata; | |||
| private SavedObjectGraph _proto; | |||
| private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||
| @@ -39,7 +39,13 @@ namespace Tensorflow.Keras.Saving | |||
| static KerasObjectLoader() | |||
| { | |||
| PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null; | |||
| var endPoints = new CommonEndPoints(); | |||
| PUBLIC_ATTRIBUTES = new Dictionary<string, Trackable>(); | |||
| foreach (var key in endPoints._all_checkpointable_objects.Concat(endPoints._all_functions)) | |||
| { | |||
| PUBLIC_ATTRIBUTES[key] = null; | |||
| } | |||
| PUBLIC_ATTRIBUTES[SavedModel.Constants.KERAS_ATTR] = null; | |||
| } | |||
| public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) | |||
| @@ -125,8 +131,14 @@ namespace Tensorflow.Keras.Saving | |||
| continue; | |||
| } | |||
| // TODO: deal with `RevivedLayer` and `RevivedInputLayer`. | |||
| layers_revived_from_config.Add(node as Layer); | |||
| if(node is RevivedLayer or RevivedInputLayer) | |||
| { | |||
| layers_revived_from_saved_model.Add(node as Layer); | |||
| } | |||
| else | |||
| { | |||
| layers_revived_from_config.Add(node as Layer); | |||
| } | |||
| } | |||
| _finalize_saved_model_layers(layers_revived_from_saved_model); | |||
| @@ -171,10 +183,13 @@ namespace Tensorflow.Keras.Saving | |||
| // TODO(Rinne): implement it | |||
| } | |||
| } | |||
| // `model.__init__(layers, config["name"])` | |||
| s.InitLayers(layers); | |||
| s.Name = config["name"].ToObject<string>(); | |||
| // `model.__init__(layers, config["name"])`InitLayers(layers); | |||
| s = new Sequential(new SequentialArgs(){ | |||
| Layers = layers.Select(x => x as ILayer).ToList(), | |||
| Name = config["name"].ToObject<string>() | |||
| }); | |||
| //s.Name = config["name"].ToObject<string>(); | |||
| if(s.input is null || s.input.Length == 0) | |||
| { | |||
| var first_layer = _get_child_layer_node_ids(model_id)[0]; | |||
| @@ -205,7 +220,12 @@ namespace Tensorflow.Keras.Saving | |||
| private void _set_network_attributes_from_metadata(Model revived_object) | |||
| { | |||
| // TODO: implement it. | |||
| var metadata = revived_object.SerializedAttributes["matadata"] as JObject; | |||
| if (metadata.ContainsKey("dtype")) | |||
| { | |||
| // TODO(Rinne): set_dtype_policy. | |||
| } | |||
| revived_object.args.Trainable = metadata["trainable"].Value<bool>(); | |||
| } | |||
| /// <summary> | |||
| @@ -330,7 +350,7 @@ namespace Tensorflow.Keras.Saving | |||
| private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id) | |||
| { | |||
| Trackable obj; | |||
| if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) | |||
| if(identifier == SavedModel.Constants.METRIC_IDENTIFIER) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| return (null, null); | |||
| @@ -429,25 +449,26 @@ namespace Tensorflow.Keras.Saving | |||
| return obj; | |||
| } | |||
| private void _revive_setter(object layer, object name, object value) | |||
| private void _revive_setter(object obj, object name, object value) | |||
| { | |||
| Debug.Assert(name is string); | |||
| Debug.Assert(layer is Layer); | |||
| Debug.Assert(obj is Layer); | |||
| Layer layer = (Layer)obj; | |||
| if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | |||
| { | |||
| if(value is Trackable) | |||
| { | |||
| (layer as Layer)._track_trackable(value as Trackable, name as string); | |||
| layer._track_trackable(value as Trackable, name as string); | |||
| } | |||
| if((layer as Layer).SerializedAttributes is null) | |||
| if(layer.SerializedAttributes is null) | |||
| { | |||
| (layer as Layer).SerializedAttributes = new JObject(); | |||
| layer.SerializedAttributes = new Dictionary<string, object>(); | |||
| } | |||
| (layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value); | |||
| layer.SerializedAttributes[name as string] = value; | |||
| } | |||
| else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||
| else if(layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||
| { | |||
| (layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true); | |||
| functional._track_trackable(value as Trackable, name as string, overwrite: true); | |||
| } | |||
| else | |||
| { | |||
| @@ -521,7 +542,7 @@ namespace Tensorflow.Keras.Saving | |||
| } | |||
| var metric_list_node_id = _search_for_child_node(node_id, new string[] { | |||
| Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics" | |||
| SavedModel.Constants.KERAS_ATTR, "layer_metrics" | |||
| }); | |||
| if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) | |||
| { | |||
| @@ -547,7 +568,7 @@ namespace Tensorflow.Keras.Saving | |||
| // skip the check for registered identifier | |||
| Action<object, object, object> setter; | |||
| if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) | |||
| if (SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) | |||
| { | |||
| setter = _revive_setter; | |||
| } | |||
| @@ -659,7 +680,23 @@ namespace Tensorflow.Keras.Saving | |||
| private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata) | |||
| { | |||
| // TODO: deal with `RevivedLayer` | |||
| if(layer.SerializedAttributes is null || layer.SerializedAttributes.Count == 0) | |||
| { | |||
| layer.SerializedAttributes = new Dictionary<string, object>(); | |||
| layer.SerializedAttributes["metadata"] = metadata; | |||
| } | |||
| } | |||
| private static object _get_keras_attr(Layer layer) | |||
| { | |||
| if((layer.SerializedAttributes ?? new Dictionary<string, object>()).TryGetValue(SavedModel.Constants.KERAS_ATTR, out var value)) | |||
| { | |||
| return value; | |||
| } | |||
| else | |||
| { | |||
| return null; | |||
| } | |||
| } | |||
| /// <summary> | |||
| @@ -24,17 +24,22 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| } | |||
| } | |||
| public static void _revive_setter(object layer, object name, object value) | |||
| public static void _revive_setter(object obj, object name, object value) | |||
| { | |||
| Debug.Assert(name is string); | |||
| Debug.Assert(layer is Layer); | |||
| Debug.Assert(obj is Layer); | |||
| Layer layer = (Layer)obj; | |||
| if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | |||
| { | |||
| if (value is Trackable trackable) | |||
| { | |||
| (layer as Layer)._track_trackable(trackable, name as string); | |||
| layer._track_trackable(trackable, name as string); | |||
| } | |||
| (layer as Layer).SerializedAttributes[name] = JToken.FromObject(value); | |||
| if (layer.SerializedAttributes is null) | |||
| { | |||
| layer.SerializedAttributes = new Dictionary<string, object>(); | |||
| } | |||
| layer.SerializedAttributes[name as string] = value; | |||
| } | |||
| else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||
| { | |||
| @@ -0,0 +1,15 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| public class RevivedInputLayer: Layer | |||
| { | |||
| private RevivedInputLayer(): base(null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| } | |||
| @@ -55,6 +55,21 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| private RevivedConfig _config = null; | |||
| public object keras_api | |||
| { | |||
| get | |||
| { | |||
| if (SerializedAttributes.TryGetValue(SavedModel.Constants.KERAS_ATTR, out var value)) | |||
| { | |||
| return value; | |||
| } | |||
| else | |||
| { | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| public RevivedLayer(LayerArgs args): base(args) | |||
| { | |||
| @@ -69,5 +84,17 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| return _config; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| { | |||
| if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function) | |||
| { | |||
| return base.Call(inputs, state, training); | |||
| } | |||
| else | |||
| { | |||
| return (func as Function).Apply(inputs); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -19,8 +19,8 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| protected IDictionary<string, Trackable?> _object_dict; | |||
| protected IDictionary<string, Trackable?> _function_dict; | |||
| protected AutoTrackable _keras_trackable; | |||
| protected HashSet<string> _all_functions; | |||
| protected HashSet<string> _all_checkpointable_objects; | |||
| internal HashSet<string> _all_functions; | |||
| internal HashSet<string> _all_checkpointable_objects; | |||
| private SerializedAttributes() | |||
| { | |||
| @@ -197,19 +197,15 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| public class CommonEndPoints: SerializedAttributes | |||
| { | |||
| public CommonEndPoints(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : | |||
| //base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), | |||
| // functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) | |||
| base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables"}), | |||
| functions.Concat(new string[] { })) | |||
| base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), | |||
| functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) | |||
| { | |||
| } | |||
| public CommonEndPoints() : | |||
| //base(new string[] { "variables", "trainable_variables", "regularization_losses" }, | |||
| // new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) | |||
| base(new string[] { "variables", "trainable_variables"}, | |||
| new string[] {}) | |||
| base(new string[] { "variables", "trainable_variables", "regularization_losses" }, | |||
| new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) | |||
| { | |||
| } | |||