| @@ -80,6 +80,14 @@ namespace Tensorflow | |||
| { | |||
| return np.array(tensor.IntVal.ToArray()).reshape(shape); | |||
| } | |||
| else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype)) | |||
| { | |||
| return np.array(tensor.Int64Val.ToArray()).reshape(shape); | |||
| } | |||
| else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype)) | |||
| { | |||
| return np.array(tensor.Uint64Val.ToArray()).reshape(shape); | |||
| } | |||
| else if (tensor.Dtype == DataType.DtBool) | |||
| { | |||
| return np.array(tensor.BoolVal.ToArray()).reshape(shape); | |||
| @@ -1,11 +0,0 @@ | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Trackables; | |||
| public class Asset : Trackable | |||
| { | |||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto() | |||
| { | |||
| return (null, null); | |||
| } | |||
| } | |||
| @@ -0,0 +1,18 @@ | |||
| using Google.Protobuf.Collections; | |||
| using System.IO; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Trackables; | |||
| public class AssetResource : Trackable | |||
| { | |||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||
| string export_dir, | |||
| RepeatedField<AssetFileDef> asset_file_def, | |||
| Dictionary<string, MapField<string, AttrValue>> operation_attributes) | |||
| { | |||
| var proto = object_proto.Asset; | |||
| var filename = Path.Combine(export_dir, asset_file_def[proto.AssetFileDefIndex].Filename); | |||
| return (new AssetResource(), null); | |||
| } | |||
| } | |||
| @@ -1,12 +1,13 @@ | |||
| using System.Runtime.CompilerServices; | |||
| using Google.Protobuf.Collections; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Trackables; | |||
| public class RestoredResource : TrackableResource | |||
| { | |||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto() | |||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||
| Dictionary<string, MapField<string, AttrValue>> operation_attributes) | |||
| { | |||
| return (null, null); | |||
| return (new RestoredResource(), null); | |||
| } | |||
| } | |||
| @@ -1,11 +1,22 @@ | |||
| using Tensorflow.Train; | |||
| using Google.Protobuf.Collections; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Trackables; | |||
| public class TrackableConstant : Trackable | |||
| { | |||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto() | |||
| Tensor _constant; | |||
| public TrackableConstant(Tensor constant) | |||
| { | |||
| return (null, null); | |||
| _constant = constant; | |||
| } | |||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||
| Dictionary<string, MapField<string, AttrValue>> operation_attributes) | |||
| { | |||
| var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor; | |||
| var ndarray = tensor_util.MakeNdarray(tensor_proto); | |||
| var imported_constant = constant_op.constant(ndarray); | |||
| return (new TrackableConstant(imported_constant), null); | |||
| } | |||
| } | |||
| @@ -9,6 +9,19 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| public static class function_deserialization | |||
| { | |||
| /// <summary> | |||
| /// Creates a `Function` from a `SavedFunction`. | |||
| /// </summary> | |||
| /// <param name="saved_concrete_function"></param> | |||
| /// <param name="concrete_functions"></param> | |||
| /// <returns></returns> | |||
| public static ConcreteFunction recreate_function(SavedFunction saved_concrete_function, | |||
| IDictionary<string, ConcreteFunction> concrete_functions) | |||
| { | |||
| var function_spec = _deserialize_function_spec_as_nonmethod(saved_concrete_function.FunctionSpec); | |||
| return null; | |||
| } | |||
| public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, | |||
| IDictionary<string, ConcreteFunction> concrete_functions) | |||
| { | |||
| @@ -387,13 +387,6 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| // skip the function and concrete function. | |||
| if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function) | |||
| { | |||
| nodes[node_id] = null; | |||
| node_setters[node_id] = null; | |||
| continue; | |||
| } | |||
| var (node, setter) = _recreate(proto, node_id, nodes); | |||
| nodes[node_id] = node; | |||
| node_setters[node_id] = setter; | |||
| @@ -471,6 +464,11 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private void _setup_function_captures() | |||
| { | |||
| // TODO: implement it with concrete functions. | |||
| } | |||
| private void _setup_remaining_functions() | |||
| { | |||
| // TODO: implement it with concrete functions. | |||
| @@ -542,9 +540,9 @@ namespace Tensorflow | |||
| return proto.KindCase switch | |||
| { | |||
| SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(), | |||
| SavedObject.KindOneofCase.Asset => Asset.deserialize_from_proto(), | |||
| SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(), | |||
| SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(proto, _operation_attributes), | |||
| SavedObject.KindOneofCase.Asset => AssetResource.deserialize_from_proto(proto, _export_dir, _asset_file_def, _operation_attributes), | |||
| SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(proto, _operation_attributes), | |||
| _ => _recreate_default(proto, node_id, dependencies) | |||
| }; | |||
| } | |||
| @@ -563,7 +561,8 @@ namespace Tensorflow | |||
| SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), | |||
| SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), | |||
| SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), | |||
| SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException() | |||
| SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), | |||
| _ => throw new NotImplementedException() | |||
| }; | |||
| } | |||
| @@ -623,8 +622,12 @@ namespace Tensorflow | |||
| private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto, | |||
| Dictionary<Maybe<string, int>, Trackable> dependencies) | |||
| { | |||
| throw new NotImplementedException(); | |||
| //var fn = function_deserialization.setup_bare_concrete_function(proto, ) | |||
| var fn = function_deserialization.recreate_function(proto, null); | |||
| foreach (var name in proto.ConcreteFunctions) | |||
| { | |||
| _setup_function_captures(); | |||
| } | |||
| return (fn, setattr); | |||
| } | |||
| private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | |||