From 8550dccc56f9a33dcf89c3ec1e0a73bb1e88d2fb Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Sat, 4 Mar 2023 21:58:32 -0600 Subject: [PATCH] Add missing trackable class but not implemented. --- .../Operations/SafeOperationHandle.cs | 40 +++++++++++++++++++ src/TensorFlowNET.Core/Tensors/Tensors.cs | 12 ++++++ src/TensorFlowNET.Core/Trackables/Asset.cs | 11 +++++ .../Trackables/CapturableResource.cs | 7 ++++ .../Trackables/RestoredResource.cs | 12 ++++++ .../Trackables/TrackableConstant.cs | 11 +++++ .../Trackables/TrackableResource.cs | 5 +++ .../Training/Saving/SavedModel/loader.cs | 26 ++++++++++-- 8 files changed, 120 insertions(+), 4 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/SafeOperationHandle.cs create mode 100644 src/TensorFlowNET.Core/Trackables/Asset.cs create mode 100644 src/TensorFlowNET.Core/Trackables/CapturableResource.cs create mode 100644 src/TensorFlowNET.Core/Trackables/RestoredResource.cs create mode 100644 src/TensorFlowNET.Core/Trackables/TrackableConstant.cs create mode 100644 src/TensorFlowNET.Core/Trackables/TrackableResource.cs diff --git a/src/TensorFlowNET.Core/Operations/SafeOperationHandle.cs b/src/TensorFlowNET.Core/Operations/SafeOperationHandle.cs new file mode 100644 index 00000000..41364fe6 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/SafeOperationHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Util; + +namespace Tensorflow; + +public sealed class SafeOperationHandle : SafeTensorflowHandle +{ + private SafeOperationHandle() + { + } + + public SafeOperationHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + var status = new Status(); + // c_api.TF_CloseSession(handle, status); + c_api.TF_DeleteSession(handle, status); + SetHandle(IntPtr.Zero); + return true; + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 7fa4dd44..60972775 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -65,6 +65,18 @@ namespace Tensorflow IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public string[] StringData() + { + EnsureSingleTensor(this, "nnumpy"); + return this[0].StringData(); + } + + public string StringData(int index) + { + EnsureSingleTensor(this, "nnumpy"); + return this[0].StringData(index); + } + public NDArray numpy() { EnsureSingleTensor(this, "nnumpy"); diff --git a/src/TensorFlowNET.Core/Trackables/Asset.cs b/src/TensorFlowNET.Core/Trackables/Asset.cs new file mode 100644 index 00000000..cf4c6875 --- /dev/null +++ b/src/TensorFlowNET.Core/Trackables/Asset.cs @@ -0,0 +1,11 @@ +using Tensorflow.Train; + +namespace Tensorflow.Trackables; + +public class Asset : Trackable +{ + public static (Trackable, Action) deserialize_from_proto() + { + return (null, null); + } +} diff --git a/src/TensorFlowNET.Core/Trackables/CapturableResource.cs b/src/TensorFlowNET.Core/Trackables/CapturableResource.cs new file mode 100644 index 00000000..d93f786d --- /dev/null +++ b/src/TensorFlowNET.Core/Trackables/CapturableResource.cs @@ -0,0 +1,7 @@ +using Tensorflow.Train; + +namespace Tensorflow.Trackables; + +public class CapturableResource : Trackable +{ +} diff --git a/src/TensorFlowNET.Core/Trackables/RestoredResource.cs b/src/TensorFlowNET.Core/Trackables/RestoredResource.cs new file mode 100644 index 00000000..0d1267d4 --- /dev/null +++ b/src/TensorFlowNET.Core/Trackables/RestoredResource.cs @@ -0,0 +1,12 @@ +using System.Runtime.CompilerServices; +using Tensorflow.Train; + +namespace Tensorflow.Trackables; + +public class RestoredResource : TrackableResource +{ + public static (Trackable, Action) deserialize_from_proto() + { + return (null, null); + } +} diff --git a/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs b/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs new file mode 100644 index 00000000..7e7f40ec --- /dev/null +++ b/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs @@ -0,0 +1,11 @@ +using Tensorflow.Train; + +namespace Tensorflow.Trackables; + +public class TrackableConstant : Trackable +{ + public static (Trackable, Action) deserialize_from_proto() + { + return (null, null); + } +} diff --git a/src/TensorFlowNET.Core/Trackables/TrackableResource.cs b/src/TensorFlowNET.Core/Trackables/TrackableResource.cs new file mode 100644 index 00000000..43cbc5a2 --- /dev/null +++ b/src/TensorFlowNET.Core/Trackables/TrackableResource.cs @@ -0,0 +1,5 @@ +namespace Tensorflow.Trackables; + +public class TrackableResource : CapturableResource +{ +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index da999b37..d1ff95ca 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -13,6 +13,7 @@ using System.Runtime.CompilerServices; using Tensorflow.Variables; using Tensorflow.Functions; using Tensorflow.Training.Saving.SavedModel; +using Tensorflow.Trackables; namespace Tensorflow { @@ -51,9 +52,13 @@ namespace Tensorflow _node_filters = filters; _node_path_to_id = _convert_node_paths_to_ints(); _loaded_nodes = new Dictionary)>(); - foreach(var filter in filters) + + if (filters != null) { - _loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value; + foreach (var filter in filters) + { + _loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value; + } } _filtered_nodes = _retrieve_all_filtered_nodes(); @@ -535,7 +540,13 @@ namespace Tensorflow dependencies[item.Key] = nodes[item.Value]; } - return _recreate_default(proto, node_id, dependencies); + return proto.KindCase switch + { + SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(), + SavedObject.KindOneofCase.Asset => Asset.deserialize_from_proto(), + SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(), + _ => _recreate_default(proto, node_id, dependencies) + }; } /// @@ -549,7 +560,7 @@ namespace Tensorflow return proto.KindCase switch { SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), - SavedObject.KindOneofCase.Function => throw new NotImplementedException(), + SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException() @@ -609,6 +620,13 @@ namespace Tensorflow } } + private (ConcreteFunction, Action) _recreate_function(SavedFunction proto, + Dictionary, Trackable> dependencies) + { + throw new NotImplementedException(); + //var fn = function_deserialization.setup_bare_concrete_function(proto, ) + } + private (ConcreteFunction, Action) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, Dictionary, Trackable> dependencies) {