From bdf229acbb1b1d5127bf05b5f2f80974bbf4342d Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Mon, 6 Mar 2023 20:50:37 -0600 Subject: [PATCH] Renmae to AssetResource. --- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 8 +++++ src/TensorFlowNET.Core/Trackables/Asset.cs | 11 ------- .../Trackables/AssetResource.cs | 18 ++++++++++++ .../Trackables/RestoredResource.cs | 7 +++-- .../Trackables/TrackableConstant.cs | 17 +++++++++-- .../SavedModel/function_deserialization.cs | 13 +++++++++ .../Training/Saving/SavedModel/loader.cs | 29 ++++++++++--------- 7 files changed, 73 insertions(+), 30 deletions(-) delete mode 100644 src/TensorFlowNET.Core/Trackables/Asset.cs create mode 100644 src/TensorFlowNET.Core/Trackables/AssetResource.cs diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 7af89f13..19dbd6ed 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -80,6 +80,14 @@ namespace Tensorflow { return np.array(tensor.IntVal.ToArray()).reshape(shape); } + else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype)) + { + return np.array(tensor.Int64Val.ToArray()).reshape(shape); + } + else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype)) + { + return np.array(tensor.Uint64Val.ToArray()).reshape(shape); + } else if (tensor.Dtype == DataType.DtBool) { return np.array(tensor.BoolVal.ToArray()).reshape(shape); diff --git a/src/TensorFlowNET.Core/Trackables/Asset.cs b/src/TensorFlowNET.Core/Trackables/Asset.cs deleted file mode 100644 index cf4c6875..00000000 --- a/src/TensorFlowNET.Core/Trackables/Asset.cs +++ /dev/null @@ -1,11 +0,0 @@ -using Tensorflow.Train; - -namespace Tensorflow.Trackables; - -public class Asset : Trackable -{ - public static (Trackable, Action) deserialize_from_proto() - { - return (null, null); - } -} diff --git a/src/TensorFlowNET.Core/Trackables/AssetResource.cs b/src/TensorFlowNET.Core/Trackables/AssetResource.cs new file mode 100644 index 00000000..6e8d05a8 --- /dev/null +++ b/src/TensorFlowNET.Core/Trackables/AssetResource.cs @@ -0,0 +1,18 @@ +using Google.Protobuf.Collections; +using System.IO; +using Tensorflow.Train; + +namespace Tensorflow.Trackables; + +public class AssetResource : Trackable +{ + public static (Trackable, Action) deserialize_from_proto(SavedObject object_proto, + string export_dir, + RepeatedField asset_file_def, + Dictionary> operation_attributes) + { + var proto = object_proto.Asset; + var filename = Path.Combine(export_dir, asset_file_def[proto.AssetFileDefIndex].Filename); + return (new AssetResource(), null); + } +} diff --git a/src/TensorFlowNET.Core/Trackables/RestoredResource.cs b/src/TensorFlowNET.Core/Trackables/RestoredResource.cs index 0d1267d4..cb9f6aa0 100644 --- a/src/TensorFlowNET.Core/Trackables/RestoredResource.cs +++ b/src/TensorFlowNET.Core/Trackables/RestoredResource.cs @@ -1,12 +1,13 @@ -using System.Runtime.CompilerServices; +using Google.Protobuf.Collections; using Tensorflow.Train; namespace Tensorflow.Trackables; public class RestoredResource : TrackableResource { - public static (Trackable, Action) deserialize_from_proto() + public static (Trackable, Action) deserialize_from_proto(SavedObject object_proto, + Dictionary> operation_attributes) { - return (null, null); + return (new RestoredResource(), null); } } diff --git a/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs b/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs index 7e7f40ec..6de8274a 100644 --- a/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs +++ b/src/TensorFlowNET.Core/Trackables/TrackableConstant.cs @@ -1,11 +1,22 @@ -using Tensorflow.Train; +using Google.Protobuf.Collections; +using Tensorflow.Train; namespace Tensorflow.Trackables; public class TrackableConstant : Trackable { - public static (Trackable, Action) deserialize_from_proto() + Tensor _constant; + public TrackableConstant(Tensor constant) { - return (null, null); + _constant = constant; + } + + public static (Trackable, Action) deserialize_from_proto(SavedObject object_proto, + Dictionary> operation_attributes) + { + var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor; + var ndarray = tensor_util.MakeNdarray(tensor_proto); + var imported_constant = constant_op.constant(ndarray); + return (new TrackableConstant(imported_constant), null); } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs index 5b482872..d26fe2b5 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs @@ -9,6 +9,19 @@ namespace Tensorflow.Training.Saving.SavedModel { public static class function_deserialization { + /// + /// Creates a `Function` from a `SavedFunction`. + /// + /// + /// + /// + public static ConcreteFunction recreate_function(SavedFunction saved_concrete_function, + IDictionary concrete_functions) + { + var function_spec = _deserialize_function_spec_as_nonmethod(saved_concrete_function.FunctionSpec); + return null; + } + public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, IDictionary concrete_functions) { diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index d1ff95ca..dc9e5ba5 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -387,13 +387,6 @@ namespace Tensorflow } else { - // skip the function and concrete function. - if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function) - { - nodes[node_id] = null; - node_setters[node_id] = null; - continue; - } var (node, setter) = _recreate(proto, node_id, nodes); nodes[node_id] = node; node_setters[node_id] = setter; @@ -471,6 +464,11 @@ namespace Tensorflow } } + private void _setup_function_captures() + { + // TODO: implement it with concrete functions. + } + private void _setup_remaining_functions() { // TODO: implement it with concrete functions. @@ -542,9 +540,9 @@ namespace Tensorflow return proto.KindCase switch { - SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(), - SavedObject.KindOneofCase.Asset => Asset.deserialize_from_proto(), - SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(), + SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(proto, _operation_attributes), + SavedObject.KindOneofCase.Asset => AssetResource.deserialize_from_proto(proto, _export_dir, _asset_file_def, _operation_attributes), + SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(proto, _operation_attributes), _ => _recreate_default(proto, node_id, dependencies) }; } @@ -563,7 +561,8 @@ namespace Tensorflow SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), - SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException() + SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), + _ => throw new NotImplementedException() }; } @@ -623,8 +622,12 @@ namespace Tensorflow private (ConcreteFunction, Action) _recreate_function(SavedFunction proto, Dictionary, Trackable> dependencies) { - throw new NotImplementedException(); - //var fn = function_deserialization.setup_bare_concrete_function(proto, ) + var fn = function_deserialization.recreate_function(proto, null); + foreach (var name in proto.ConcreteFunctions) + { + _setup_function_captures(); + } + return (fn, setattr); } private (ConcreteFunction, Action) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,