| @@ -80,6 +80,14 @@ namespace Tensorflow | |||||
| { | { | ||||
| return np.array(tensor.IntVal.ToArray()).reshape(shape); | return np.array(tensor.IntVal.ToArray()).reshape(shape); | ||||
| } | } | ||||
| else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype)) | |||||
| { | |||||
| return np.array(tensor.Int64Val.ToArray()).reshape(shape); | |||||
| } | |||||
| else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype)) | |||||
| { | |||||
| return np.array(tensor.Uint64Val.ToArray()).reshape(shape); | |||||
| } | |||||
| else if (tensor.Dtype == DataType.DtBool) | else if (tensor.Dtype == DataType.DtBool) | ||||
| { | { | ||||
| return np.array(tensor.BoolVal.ToArray()).reshape(shape); | return np.array(tensor.BoolVal.ToArray()).reshape(shape); | ||||
| @@ -1,11 +0,0 @@ | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Trackables; | |||||
| public class Asset : Trackable | |||||
| { | |||||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto() | |||||
| { | |||||
| return (null, null); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,18 @@ | |||||
| using Google.Protobuf.Collections; | |||||
| using System.IO; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Trackables; | |||||
| public class AssetResource : Trackable | |||||
| { | |||||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||||
| string export_dir, | |||||
| RepeatedField<AssetFileDef> asset_file_def, | |||||
| Dictionary<string, MapField<string, AttrValue>> operation_attributes) | |||||
| { | |||||
| var proto = object_proto.Asset; | |||||
| var filename = Path.Combine(export_dir, asset_file_def[proto.AssetFileDefIndex].Filename); | |||||
| return (new AssetResource(), null); | |||||
| } | |||||
| } | |||||
| @@ -1,12 +1,13 @@ | |||||
| using System.Runtime.CompilerServices; | |||||
| using Google.Protobuf.Collections; | |||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| namespace Tensorflow.Trackables; | namespace Tensorflow.Trackables; | ||||
| public class RestoredResource : TrackableResource | public class RestoredResource : TrackableResource | ||||
| { | { | ||||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto() | |||||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||||
| Dictionary<string, MapField<string, AttrValue>> operation_attributes) | |||||
| { | { | ||||
| return (null, null); | |||||
| return (new RestoredResource(), null); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,11 +1,22 @@ | |||||
| using Tensorflow.Train; | |||||
| using Google.Protobuf.Collections; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Trackables; | namespace Tensorflow.Trackables; | ||||
| public class TrackableConstant : Trackable | public class TrackableConstant : Trackable | ||||
| { | { | ||||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto() | |||||
| Tensor _constant; | |||||
| public TrackableConstant(Tensor constant) | |||||
| { | { | ||||
| return (null, null); | |||||
| _constant = constant; | |||||
| } | |||||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||||
| Dictionary<string, MapField<string, AttrValue>> operation_attributes) | |||||
| { | |||||
| var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor; | |||||
| var ndarray = tensor_util.MakeNdarray(tensor_proto); | |||||
| var imported_constant = constant_op.constant(ndarray); | |||||
| return (new TrackableConstant(imported_constant), null); | |||||
| } | } | ||||
| } | } | ||||
| @@ -9,6 +9,19 @@ namespace Tensorflow.Training.Saving.SavedModel | |||||
| { | { | ||||
| public static class function_deserialization | public static class function_deserialization | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Creates a `Function` from a `SavedFunction`. | |||||
| /// </summary> | |||||
| /// <param name="saved_concrete_function"></param> | |||||
| /// <param name="concrete_functions"></param> | |||||
| /// <returns></returns> | |||||
| public static ConcreteFunction recreate_function(SavedFunction saved_concrete_function, | |||||
| IDictionary<string, ConcreteFunction> concrete_functions) | |||||
| { | |||||
| var function_spec = _deserialize_function_spec_as_nonmethod(saved_concrete_function.FunctionSpec); | |||||
| return null; | |||||
| } | |||||
| public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, | public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, | ||||
| IDictionary<string, ConcreteFunction> concrete_functions) | IDictionary<string, ConcreteFunction> concrete_functions) | ||||
| { | { | ||||
| @@ -387,13 +387,6 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| // skip the function and concrete function. | |||||
| if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function) | |||||
| { | |||||
| nodes[node_id] = null; | |||||
| node_setters[node_id] = null; | |||||
| continue; | |||||
| } | |||||
| var (node, setter) = _recreate(proto, node_id, nodes); | var (node, setter) = _recreate(proto, node_id, nodes); | ||||
| nodes[node_id] = node; | nodes[node_id] = node; | ||||
| node_setters[node_id] = setter; | node_setters[node_id] = setter; | ||||
| @@ -471,6 +464,11 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private void _setup_function_captures() | |||||
| { | |||||
| // TODO: implement it with concrete functions. | |||||
| } | |||||
| private void _setup_remaining_functions() | private void _setup_remaining_functions() | ||||
| { | { | ||||
| // TODO: implement it with concrete functions. | // TODO: implement it with concrete functions. | ||||
| @@ -542,9 +540,9 @@ namespace Tensorflow | |||||
| return proto.KindCase switch | return proto.KindCase switch | ||||
| { | { | ||||
| SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(), | |||||
| SavedObject.KindOneofCase.Asset => Asset.deserialize_from_proto(), | |||||
| SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(), | |||||
| SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(proto, _operation_attributes), | |||||
| SavedObject.KindOneofCase.Asset => AssetResource.deserialize_from_proto(proto, _export_dir, _asset_file_def, _operation_attributes), | |||||
| SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(proto, _operation_attributes), | |||||
| _ => _recreate_default(proto, node_id, dependencies) | _ => _recreate_default(proto, node_id, dependencies) | ||||
| }; | }; | ||||
| } | } | ||||
| @@ -563,7 +561,8 @@ namespace Tensorflow | |||||
| SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), | SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), | ||||
| SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), | SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), | ||||
| SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), | SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), | ||||
| SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException() | |||||
| SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), | |||||
| _ => throw new NotImplementedException() | |||||
| }; | }; | ||||
| } | } | ||||
| @@ -623,8 +622,12 @@ namespace Tensorflow | |||||
| private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto, | private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto, | ||||
| Dictionary<Maybe<string, int>, Trackable> dependencies) | Dictionary<Maybe<string, int>, Trackable> dependencies) | ||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| //var fn = function_deserialization.setup_bare_concrete_function(proto, ) | |||||
| var fn = function_deserialization.recreate_function(proto, null); | |||||
| foreach (var name in proto.ConcreteFunctions) | |||||
| { | |||||
| _setup_function_captures(); | |||||
| } | |||||
| return (fn, setattr); | |||||
| } | } | ||||
| private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | ||||