| @@ -30,7 +30,7 @@ namespace Tensorflow.Checkpoint | |||
| ); | |||
| public static class SaveUtil | |||
| { | |||
| public static (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| public static (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | |||
| { | |||
| var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | |||
| @@ -119,16 +119,16 @@ namespace Tensorflow.Checkpoint | |||
| /// <param name="call_with_mapped_captures"></param> | |||
| /// <param name="cache"></param> | |||
| /// <param name="object_graph_proto"></param> | |||
| private static IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
| private static IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
| bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | |||
| { | |||
| Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
| Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||
| foreach(var td in tensor_trackables) | |||
| { | |||
| // TODO: deal with cache. | |||
| var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | |||
| Trackable trackable = null; | |||
| IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||
| IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict; | |||
| if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | |||
| { | |||
| (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | |||
| @@ -150,12 +150,12 @@ namespace Tensorflow.Checkpoint | |||
| return serialized_tensors; | |||
| } | |||
| private static IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
| private static IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
| { | |||
| var trackable = trackable_data.object_to_save; | |||
| // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | |||
| IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||
| IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> ret_tensor_dict; | |||
| if (call_with_mapped_captures) | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -165,8 +165,7 @@ namespace Tensorflow.Checkpoint | |||
| ret_tensor_dict = trackable.serialize_to_tensors(); | |||
| } | |||
| // TODO: deal with the type `SaveSpce` (currently it will never be it). | |||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
| Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new(); | |||
| foreach(var pair in ret_tensor_dict) | |||
| { | |||
| var local_name = TrackableUtils.escape_local_name(pair.Key); | |||
| @@ -175,10 +174,12 @@ namespace Tensorflow.Checkpoint | |||
| tensor_dict[checkpoint_key] = maybe_tensor; | |||
| if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>()) | |||
| foreach(var key in maybe_tensor.Keys) | |||
| { | |||
| throw new NotImplementedException(); | |||
| //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||
| if (maybe_tensor[key].IsTypeOrDeriveFrom<SaveSpec>()) | |||
| { | |||
| maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name; | |||
| } | |||
| } | |||
| if(object_graph_proto is not null) | |||
| @@ -202,7 +203,7 @@ namespace Tensorflow.Checkpoint | |||
| /// <param name="call_with_mapped_captures"></param> | |||
| /// <param name="object_graph_proto"></param> | |||
| /// <returns></returns> | |||
| private static (Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
| private static (Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
| bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
| { | |||
| Dictionary<Trackable, string> object_names = new(); | |||
| @@ -45,12 +45,12 @@ public class TrackableSaver | |||
| _graph_view = graph_view; | |||
| // TODO: cache when not executing eagerly. | |||
| // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, | |||
| // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder` | |||
| // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | |||
| } | |||
| private (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| private (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| gather_serialized_tensors(Tensor? object_graph_tensor = null) | |||
| { | |||
| var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | |||
| @@ -69,9 +69,10 @@ public class TrackableSaver | |||
| Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
| if (!serialized_tensors.ContainsKey(Trackable.None)) | |||
| { | |||
| serialized_tensors[Trackable.None] = new Dictionary<string, OneOf.OneOf<Tensor, IDictionary<string, Tensor>>>(); | |||
| serialized_tensors[Trackable.None] = new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>(); | |||
| } | |||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; | |||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY].Add(saveable_object_util.NO_SLICE_SPEC_KEY, object_graph_tensor); | |||
| return (serialized_tensors, feed_additions, registered_savers, graph_proto); | |||
| } | |||
| @@ -387,6 +388,7 @@ public class CheckpointRestoreCoordinator | |||
| /// </summary> | |||
| public List<Trackable> AllTrackables => _all_trackables; | |||
| public HashSet<int> MatchedProtoIds => _matched_proto_ids; | |||
| // TODO(Rinne): change to weak ref. | |||
| public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; | |||
| public int RestoreUid => _restore_uid; | |||
| public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; | |||
| @@ -160,12 +160,12 @@ namespace Tensorflow.Checkpoint | |||
| /// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | |||
| /// <param name="registered_savers"></param> | |||
| /// <param name="call_with_mapped_capture"></param> | |||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors, | |||
| IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | |||
| { | |||
| _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | |||
| _restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>(); | |||
| Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | |||
| Dictionary<string, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> tensors_by_device= new(); | |||
| foreach(var pair in serialized_tensors) | |||
| { | |||
| @@ -191,16 +191,7 @@ namespace Tensorflow.Checkpoint | |||
| foreach(var item in tensor_dict) | |||
| { | |||
| var checkpoint_key = item.Key; | |||
| IDictionary<string, Tensor> spec_to_tensor; | |||
| if(item.Value.TryPickT0(out var t, out var dic)) | |||
| { | |||
| spec_to_tensor = new Dictionary<string, Tensor>(); | |||
| spec_to_tensor[""] = t; | |||
| } | |||
| else | |||
| { | |||
| spec_to_tensor = dic; | |||
| } | |||
| var spec_to_tensor = item.Value; | |||
| foreach(var spec in spec_to_tensor) | |||
| { | |||
| @@ -216,11 +207,19 @@ namespace Tensorflow.Checkpoint | |||
| _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | |||
| // skip the process of device name because lack of API. | |||
| var host_device = tensor.Device; | |||
| var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>()); | |||
| string host_device; | |||
| if (tensor.IsT0) | |||
| { | |||
| host_device = tensor.AsT0.Device; | |||
| } | |||
| else | |||
| { | |||
| host_device = tensor.AsT1.device; | |||
| } | |||
| var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>()); | |||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||
| { | |||
| internal_dict[checkpoint_key] = new Dictionary<string, Tensor>(); | |||
| internal_dict[checkpoint_key] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||
| } | |||
| internal_dict[checkpoint_key][slice_spec] = tensor; | |||
| } | |||
| @@ -425,7 +424,7 @@ namespace Tensorflow.Checkpoint | |||
| public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | |||
| { | |||
| Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
| Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||
| foreach (var saveable in saveables) | |||
| { | |||
| var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | |||
| @@ -3,6 +3,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Security; | |||
| using System.Text; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| @@ -50,7 +51,7 @@ public class CheckpointPosition | |||
| { | |||
| _checkpoint.AllTrackables.Add(trackable); | |||
| _checkpoint.MatchedProtoIds.Add(_proto_id); | |||
| if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) | |||
| if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment) && current_assignment is not null) | |||
| { | |||
| // skip the `logging.warning`. | |||
| return false; | |||
| @@ -120,6 +120,11 @@ namespace Tensorflow.Contexts | |||
| name : | |||
| "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | |||
| public string anonymous_name() | |||
| { | |||
| return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | |||
| } | |||
| public void graph_mode(bool isFunc = false) | |||
| => context_switches.Push(false, isFunc); | |||
| @@ -6,8 +6,11 @@ | |||
| public class DenseSpec : TypeSpec | |||
| { | |||
| protected Shape _shape; | |||
| public Shape shape => _shape; | |||
| public Shape shape | |||
| { | |||
| get { return _shape; } | |||
| set { _shape = value; } | |||
| } | |||
| protected TF_DataType _dtype; | |||
| public TF_DataType dtype => _dtype; | |||
| @@ -311,7 +311,7 @@ namespace Tensorflow | |||
| /// <param name="types">const TF_DataType*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphSetOutputHandleShapesAndTypes(IntPtr graph, TF_Output output, | |||
| public static extern void TF_GraphSetOutputHandleShapesAndTypes(SafeGraphHandle graph, TF_Output output, | |||
| int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types, | |||
| SafeStatusHandle status); | |||
| @@ -30,6 +30,18 @@ namespace Tensorflow.Operations | |||
| } | |||
| } | |||
| public static HandleData create_handle_data(Shape shape, TF_DataType dtype) | |||
| { | |||
| HandleData handle_data = new(); | |||
| handle_data.IsSet = true; | |||
| handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||
| { | |||
| Shape = shape.as_proto(), | |||
| Dtype = dtype.as_datatype_enum() | |||
| }); | |||
| return handle_data; | |||
| } | |||
| public static void set_handle_data(Tensor target_t, HandleData handle_data) | |||
| { | |||
| if(target_t is EagerTensor) | |||
| @@ -37,7 +49,8 @@ namespace Tensorflow.Operations | |||
| target_t.HandleData = handle_data; | |||
| return; | |||
| } | |||
| c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); | |||
| // TODO(Rinne): enable it. (currently the internal c api cannot be invoked.) | |||
| //c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); | |||
| } | |||
| } | |||
| } | |||
| @@ -21,6 +21,9 @@ using Tensorflow.Train; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using Tensorflow.Variables; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Operations; | |||
| using System.Buffers; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -31,6 +34,7 @@ namespace Tensorflow | |||
| { | |||
| public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null) | |||
| { | |||
| // TODO(Rinne): deal with `_handle_graph`. | |||
| var value_tensor = ops.convert_to_tensor(value); | |||
| return gen_resource_variable_ops.assign_variable_op(handle, | |||
| value_tensor, | |||
| @@ -78,6 +82,18 @@ namespace Tensorflow | |||
| string shared_name, string name, bool graph_mode, Tensor initial_value = null) | |||
| { | |||
| var container = ops.get_default_graph().Container; | |||
| if(container is null) | |||
| { | |||
| container = ""; | |||
| } | |||
| if (!graph_mode) | |||
| { | |||
| if(shared_name is not null) | |||
| { | |||
| throw new Exception("Using an explicit shared_name is not allowed when executing eagerly."); | |||
| } | |||
| shared_name = tf.Context.anonymous_name(); | |||
| } | |||
| var handle = gen_resource_variable_ops.var_handle_op(shape: shape, | |||
| dtype: dtype, | |||
| shared_name: shared_name, | |||
| @@ -95,26 +111,20 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| // We do not want two distinct ResourceVariable objects for the same | |||
| // underlying resource in the runtime. | |||
| // When in eager mode, explicitly ensure so here. When in graph mode, it's | |||
| // ensured by always generating different variable names. | |||
| var exists = gen_resource_variable_ops.var_is_initialized_op(handle); | |||
| // We create an assert Op instead of checking right away in order to be | |||
| // compatible with ASYNC execution mode. Further, since not all devices | |||
| // support string tensors, we encode the assertion string in the Op name | |||
| /*gen_logging_ops.assert(gen_math_ops.logical_not(exists), | |||
| new[] { exists }, | |||
| name: "EagerVariableNameReuse");*/ | |||
| var handle_data = new HandleData(); | |||
| handle_data.IsSet = true; | |||
| handle_data.ShapeAndType.Add(new HandleShapeAndType | |||
| var handle_data = handle_data_util.create_handle_data(shape, dtype); | |||
| if (initial_value is not null && initial_value.dtype == dtypes.variant) | |||
| { | |||
| Dtype = dtype.as_datatype_enum(), | |||
| Shape = shape.as_proto() | |||
| }); | |||
| var extra_handle_data = get_eager_safe_handle_data(initial_value); | |||
| if (extra_handle_data is not null && extra_handle_data.IsSet) | |||
| { | |||
| if (!handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||
| { | |||
| throw new RuntimeError($"Expected VarHandleOp to return a length==1 shape_and_type, " + | |||
| $"but saw: '{handle_data}'"); | |||
| } | |||
| handle_data.ShapeAndType.AddRange(extra_handle_data.ShapeAndType); | |||
| } | |||
| } | |||
| _set_handle_shapes_and_types(handle, handle_data, graph_mode); | |||
| return handle; | |||
| } | |||
| @@ -126,24 +136,48 @@ namespace Tensorflow | |||
| /// <param name="handle"></param> | |||
| /// <param name="handle_data"></param> | |||
| /// <param name="graph_mode"></param> | |||
| internal static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
| internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
| { | |||
| tensor.HandleData = handle_data; | |||
| if (!graph_mode) | |||
| return; | |||
| var size = handle_data.ShapeAndType.Count; | |||
| //var shapes = handle_data.ShapeAndType.Select(x => x.Shape); | |||
| //var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray(); | |||
| //var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray(); | |||
| //var converted_shapes = shapes.Select<TensorShapeProto, Memory<int>>(s => | |||
| //{ | |||
| // if (!s.UnknownRank) | |||
| // { | |||
| // return s.Dim.Select(d => (int)d.Size).ToArray(); | |||
| // } | |||
| // else | |||
| // { | |||
| // return Memory<int>.Empty; | |||
| // } | |||
| //}).ToArray(); | |||
| var shapes = new IntPtr[size]; | |||
| var types = new DataType[size]; | |||
| var ranks = new int[size]; | |||
| //List<MemoryHandle> handles = new(); | |||
| //IntPtr[] shapes_with_ptr = new IntPtr[converted_shapes.Length]; | |||
| //foreach(var (i, m) in enumerate(converted_shapes)) | |||
| //{ | |||
| // if(m.IsEmpty) | |||
| // { | |||
| // shapes_with_ptr[i] = IntPtr.Zero; | |||
| // } | |||
| // else | |||
| // { | |||
| // var handle = m.Pin(); | |||
| // handles.Add(handle); | |||
| // shapes_with_ptr[i] = new IntPtr(handle.Pointer); | |||
| // } | |||
| //} | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| var shapeAndType = handle_data.ShapeAndType[i]; | |||
| types[i] = shapeAndType.Dtype; | |||
| ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; | |||
| var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); | |||
| } | |||
| //Status status = new(); | |||
| //// TODO(Rinne): enable it. | |||
| //c_api.TF_GraphSetOutputHandleShapesAndTypes(tensor.op.graph.c_graph, tensor._as_tf_output(), | |||
| // shapes_with_ptr.Length, shapes_with_ptr, ranks, types, status); | |||
| //handles = null; | |||
| } | |||
| /// <summary> | |||
| @@ -330,7 +330,7 @@ namespace Tensorflow { | |||
| private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | |||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||
| private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_dependencies_codec | |||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||
| = pb::FieldCodec.ForMessage(122, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||
| /// <summary> | |||
| @@ -698,9 +698,13 @@ namespace Tensorflow { | |||
| break; | |||
| case 10: { | |||
| children_.AddEntriesFrom(input, _repeated_children_codec); | |||
| dependencies_.AddRange(children_.Except(dependencies_)); | |||
| break; | |||
| } | |||
| case 122: | |||
| { | |||
| dependencies_.AddEntriesFrom(input, _repeated_dependencies_codec); | |||
| break; | |||
| } | |||
| case 26: { | |||
| slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); | |||
| break; | |||
| @@ -3,6 +3,7 @@ using System.Linq; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Operations.Activation; | |||
| using Tensorflow.Training; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Train | |||
| @@ -25,6 +26,13 @@ namespace Tensorflow.Train | |||
| } | |||
| } | |||
| public override void SetAttr(string name, object value) | |||
| { | |||
| // TODO(Rinne): deal with `self_setattr_tracking`. | |||
| value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); | |||
| base.SetAttr(name, value); | |||
| } | |||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
| { | |||
| if(save_type != SaveType.SAVEDMODEL) | |||
| @@ -34,6 +42,7 @@ namespace Tensorflow.Train | |||
| Dictionary<string, Trackable> functions = new(); | |||
| // TODO: process of logs. | |||
| // TODO(Rinne): deal with members. | |||
| var properties = this.GetType().GetProperties(); | |||
| foreach ( var property in properties ) | |||
| { | |||
| @@ -45,6 +54,16 @@ namespace Tensorflow.Train | |||
| } | |||
| } | |||
| foreach(var item in CustomizedFields) | |||
| { | |||
| var name = item.Key; | |||
| var value = item.Value; | |||
| if (value is Function or ConcreteFunction) | |||
| { | |||
| functions[name] = (Trackable)value; | |||
| } | |||
| } | |||
| // TODO: process the type `core_types.GenericFunction`. | |||
| Dictionary<string, Trackable> children = new(); | |||
| @@ -42,22 +42,25 @@ namespace Tensorflow | |||
| _var_device = var.Device; | |||
| _var_shape = var.shape; | |||
| Tensor _read_variable_closure(BaseResourceVariable v) | |||
| Func<Tensor> _read_variable_closure(BaseResourceVariable v) | |||
| { | |||
| tf.device(v.Device); | |||
| if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||
| return () => | |||
| { | |||
| return null; | |||
| } | |||
| var x = v.read_value_no_copy(); | |||
| tf.device("/device:CPU:0"); | |||
| return array_ops.identity(x); | |||
| tf.device(v.Device); | |||
| if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||
| { | |||
| return null; | |||
| } | |||
| var x = v.read_value_no_copy(); | |||
| tf.device("/device:CPU:0"); | |||
| return array_ops.identity(x); | |||
| }; | |||
| } | |||
| this.handle_op = var.Handle; | |||
| var tensor = _read_variable_closure(var); | |||
| var tensor_creator = _read_variable_closure(var); | |||
| var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype); | |||
| var spec = new SaveSpec(tensor_creator, slice_spec, name, dtype: var.dtype, device: var.Device); | |||
| _op = var; | |||
| specs = new SaveSpec[] { spec }; | |||
| this.name = name; | |||
| @@ -66,6 +69,7 @@ namespace Tensorflow | |||
| public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | |||
| { | |||
| var restored_tensor = restored_tensors[0]; | |||
| tf.device(_var_device); | |||
| restored_tensor = array_ops.identity(restored_tensor); | |||
| return resource_variable_ops.shape_safe_assign_variable_handle( | |||
| handle_op, _var_shape, restored_tensor); | |||
| @@ -14,6 +14,8 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Tensorflow.Exceptions; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| @@ -21,8 +23,24 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public class SaveSpec | |||
| { | |||
| private Tensor _tensor; | |||
| public Tensor tensor => _tensor; | |||
| private Tensor _tensor = null; | |||
| private Func<Tensor> _tensor_creator = null; | |||
| public Tensor tensor | |||
| { | |||
| get | |||
| { | |||
| if(_tensor is not null || _tensor_creator is null) | |||
| { | |||
| return _tensor; | |||
| } | |||
| else | |||
| { | |||
| return _tensor_creator(); | |||
| } | |||
| } | |||
| } | |||
| internal Func<Tensor> TensorCreator => _tensor_creator; | |||
| private string _slice_spec; | |||
| public string slice_spec => _slice_spec; | |||
| @@ -32,13 +50,36 @@ namespace Tensorflow | |||
| private TF_DataType _dtype; | |||
| public TF_DataType dtype => _dtype; | |||
| private string _device; | |||
| public string device => _device; | |||
| public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null) | |||
| { | |||
| _tensor = tensor; | |||
| _slice_spec = slice_spec; | |||
| _name = name; | |||
| _dtype = dtype; | |||
| if(device is not null) | |||
| { | |||
| _device = device; | |||
| } | |||
| else | |||
| { | |||
| _device = tensor.Device; | |||
| } | |||
| } | |||
| public SaveSpec(Func<Tensor> tensor_creator, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null) | |||
| { | |||
| _tensor_creator = tensor_creator; | |||
| _slice_spec = slice_spec; | |||
| _name = name; | |||
| if(dtype == TF_DataType.DtInvalid || device is null) | |||
| { | |||
| throw new AssertionError("When passing a callable `tensor` to a SaveSpec, an explicit dtype and device must be provided."); | |||
| } | |||
| _dtype = dtype; | |||
| _device = device; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,10 +1,20 @@ | |||
| using System; | |||
| using System.Diagnostics; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| namespace Tensorflow; | |||
| public class RevivedTypes | |||
| { | |||
| private static Dictionary<string, ITrackableWrapper> _registered_revived_creator = new(); | |||
| static RevivedTypes() | |||
| { | |||
| var list_wrapper = new ListWrapper(new Trackable[] { }); | |||
| _registered_revived_creator[list_wrapper.Identifier] = list_wrapper; | |||
| var dict_wrapper = new DictWrapper(new Dictionary<object, Trackable>()); | |||
| _registered_revived_creator[dict_wrapper.Identifier] = dict_wrapper; | |||
| } | |||
| /// <summary> | |||
| /// Create a SavedUserObject from a trackable object. | |||
| /// </summary> | |||
| @@ -12,13 +22,28 @@ public class RevivedTypes | |||
| /// <returns></returns> | |||
| public static SavedUserObject? serialize(Trackable obj) | |||
| { | |||
| // TODO: complete the implementation. | |||
| // TODO(Rinne): complete the implementation. | |||
| return null; | |||
| } | |||
| public static Tuple<Trackable, Action<object, object, object>> deserialize(object proto) | |||
| public static (Trackable, Action<object, object, object>) deserialize(SavedUserObject proto) | |||
| { | |||
| // TODO: complete the implementation. | |||
| return null; | |||
| if(_registered_revived_creator.TryGetValue(proto.Identifier, out var wrapper)) | |||
| { | |||
| return (wrapper.FromProto(proto), (x, y, z) => | |||
| { | |||
| if (x is not ITrackableWrapper trackable) | |||
| { | |||
| throw new TypeError($"The type is expected to be `ITrackableWrapper`, but got {x.GetType()}."); | |||
| } | |||
| Debug.Assert(y is string); | |||
| trackable.SetValue(y, z); | |||
| } | |||
| ); | |||
| } | |||
| else | |||
| { | |||
| return (null, null); | |||
| } | |||
| } | |||
| } | |||
| @@ -49,6 +49,7 @@ namespace Tensorflow | |||
| var temp = _proto.ToString(); | |||
| _export_dir = export_dir; | |||
| // TODO: `this._concrete_functions` and `this._restored_concrete_functions` | |||
| // TODO(Rinne): This method is very slow, needs to be accelareted. | |||
| _concrete_functions = function_deserialization.load_function_def_library( | |||
| meta_graph.GraphDef.Library, _proto); | |||
| _restored_concrete_functions = new HashSet<string>(); | |||
| @@ -523,7 +524,7 @@ namespace Tensorflow | |||
| continue; | |||
| } | |||
| setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); | |||
| // skip the process of "__call__" | |||
| // TODO(Rinne): deal with "__call__" | |||
| } | |||
| } | |||
| @@ -595,13 +596,12 @@ namespace Tensorflow | |||
| private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id) | |||
| { | |||
| // skip the check of proto identifier because of lack of property. | |||
| var looked_up = RevivedTypes.deserialize(proto); | |||
| if(looked_up is null) | |||
| var (trackable, setter) = RevivedTypes.deserialize(proto); | |||
| if(trackable is null) | |||
| { | |||
| return _recreate_base_user_object(proto, node_id); | |||
| } | |||
| return (looked_up.Item1, looked_up.Item2); | |||
| return (trackable, setter); | |||
| } | |||
| private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null) | |||
| @@ -668,13 +668,20 @@ namespace Tensorflow | |||
| public static Action<object, object, object> setattr = (x, y, z) => | |||
| { | |||
| Debug.Assert(y is string); | |||
| var properties = x.GetType().GetProperties(); | |||
| foreach(var p in properties) | |||
| if(x is Trackable trackable) | |||
| { | |||
| trackable.SetAttr(y as string, z); | |||
| } | |||
| else | |||
| { | |||
| if((string)y == p.Name) | |||
| var properties = x.GetType().GetProperties(); | |||
| foreach (var p in properties) | |||
| { | |||
| p.SetValue(x, z); | |||
| return; | |||
| if ((string)y == p.Name) | |||
| { | |||
| p.SetValue(x, z); | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| // TODO(Rinne): check if the property has been set successfully. | |||
| @@ -50,6 +50,10 @@ namespace Tensorflow | |||
| } | |||
| public static class saveable_object_util | |||
| { | |||
| public static string NO_SLICE_SPEC_KEY = ""; | |||
| private static HashSet<string> _VARIABLE_OPS = new HashSet<string>(new string[] { | |||
| "Variable", "VariableV2", "AutoReloadVariable", "VarHandleOp", "ReadVariableOp" | |||
| }); | |||
| /// <summary> | |||
| /// Returns the variables and names that will be used for a Saver. | |||
| /// </summary> | |||
| @@ -123,19 +127,12 @@ namespace Tensorflow | |||
| /// <returns></returns> | |||
| public static IEnumerable<MySaveableObject> saveable_objects_for_op(Tensor op, string name) | |||
| { | |||
| if (false) | |||
| { | |||
| } | |||
| ops.init_scope(); | |||
| var variable = ops.convert_to_tensor(op, as_ref: true); | |||
| if (variable.dtype.is_ref_dtype()) | |||
| yield return new ReferenceVariableSaveable(variable, "", name); | |||
| else | |||
| { | |||
| ops.init_scope(); | |||
| var variable = ops.convert_to_tensor(op, as_ref: true); | |||
| if (variable.dtype.is_ref_dtype()) | |||
| yield return new ReferenceVariableSaveable(variable, "", name); | |||
| else | |||
| yield return new ResourceVariableSaveable(variable, "", name); | |||
| } | |||
| yield return new ResourceVariableSaveable(variable, "", name); | |||
| } | |||
| /// <summary> | |||
| @@ -159,7 +156,7 @@ namespace Tensorflow | |||
| yield return new ResourceVariableSaveable(variable, "", name); | |||
| } | |||
| } | |||
| else | |||
| else if(obj is not IVariableV1) | |||
| { | |||
| foreach(var pair in saveable_objects_from_trackable(obj)) | |||
| { | |||
| @@ -191,6 +188,30 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| // Variable | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| throw new ValueError($"Can only save/restore ResourceVariables when " + | |||
| $"executing eagerly, got type: {obj.GetType()}."); | |||
| } | |||
| var variable = ops.convert_to_tensor(obj, as_ref: true); | |||
| if (!_tensor_comes_from_variable(variable)) | |||
| { | |||
| throw new TypeError($"names_to_saveables must be a dict mapping string " + | |||
| $"names to Tensors/Variables. Not a variable: {variable}"); | |||
| } | |||
| if(variable.op.type == "Variable" || variable.op.type == "VariableV2" || | |||
| variable.op.type == "AutoReloadVariable") | |||
| { | |||
| yield return new ReferenceVariableSaveable(variable, "", name); | |||
| } | |||
| else | |||
| { | |||
| yield return new ResourceVariableSaveable(variable, "", name); | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
| @@ -267,24 +288,14 @@ namespace Tensorflow | |||
| foreach (var pair in tensor_dict) | |||
| { | |||
| var tensor_name = pair.Key; | |||
| var maybe_tensor = pair.Value; | |||
| var internal_dict = pair.Value; | |||
| local_names.Add(tensor_name); | |||
| string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | |||
| IDictionary<string, Tensor> internal_dict; | |||
| if (maybe_tensor.TryPickT0(out var tensor, out var dic)) | |||
| { | |||
| internal_dict = new Dictionary<string, Tensor>(); | |||
| internal_dict[""] = tensor; | |||
| } | |||
| else | |||
| { | |||
| internal_dict = dic; | |||
| } | |||
| foreach (var item in internal_dict) | |||
| { | |||
| specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | |||
| Debug.Assert(item.Value.IsT0); | |||
| specs.Add(new SaveSpec(item.Value.AsT0, item.Key, spec_name)); | |||
| } | |||
| } | |||
| return new TrackableSaveable(obj, specs, name, local_names, prefix); | |||
| @@ -316,9 +327,9 @@ namespace Tensorflow | |||
| /// Converts a list of SaveableObjects to a tensor dictionary. | |||
| /// </summary> | |||
| /// <param name="saveables"></param> | |||
| public static Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||
| public static Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||
| { | |||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
| Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new(); | |||
| foreach (var saveable in saveables) | |||
| { | |||
| foreach (var spec in saveable.specs) | |||
| @@ -326,14 +337,11 @@ namespace Tensorflow | |||
| // skip the check that if `spec` is callable. | |||
| var name = convert_to_string(spec.name); | |||
| var slice_spec = convert_to_string(spec.slice_spec); | |||
| if (!string.IsNullOrEmpty(slice_spec)) | |||
| { | |||
| tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).AsT1[slice_spec] = spec.tensor; | |||
| } | |||
| else | |||
| if (string.IsNullOrEmpty(slice_spec)) | |||
| { | |||
| tensor_dict[name] = spec.tensor; | |||
| slice_spec = NO_SLICE_SPEC_KEY; | |||
| } | |||
| tensor_dict.SetDefault(name, new Dictionary<string, OneOf<Tensor, SaveSpec>>())[slice_spec] = spec.TensorCreator is null ? spec.tensor : spec; | |||
| } | |||
| } | |||
| return tensor_dict; | |||
| @@ -397,6 +405,11 @@ namespace Tensorflow | |||
| { | |||
| return factory(key); | |||
| } | |||
| private static bool _tensor_comes_from_variable(object v) | |||
| { | |||
| return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type); | |||
| } | |||
| } | |||
| public class SaveableCompatibilityConverter: Trackable | |||
| @@ -412,7 +425,7 @@ namespace Tensorflow | |||
| public object Obj => _obj; | |||
| public IList<MySaveableObject> mySaveables=> _saveables; | |||
| public override IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
| public override IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> serialize_to_tensors() | |||
| { | |||
| return saveable_object_util.saveable_object_to_tensor_dict(_saveables); | |||
| } | |||
| @@ -85,6 +85,72 @@ namespace Tensorflow.Train | |||
| _self_saveable_object_factories = value; | |||
| } | |||
| } | |||
| public Dictionary<string, object> CustomizedFields { get; set; } = new Dictionary<string, object>(); | |||
| public virtual void SetAttr(string name, object value) | |||
| { | |||
| var t = this.GetType(); | |||
| var field_info = t.GetField(name); | |||
| if(field_info is not null) | |||
| { | |||
| field_info.SetValue(this, value); | |||
| } | |||
| else | |||
| { | |||
| CustomizedFields[name] = value; | |||
| } | |||
| // On account of performance, we don't use reflection to set the attribute if it exists in `Trackable`. | |||
| // When adding new members or properties to this class, please add corresponding process to this method. | |||
| //switch (name) | |||
| //{ | |||
| // case "_manual_tracking": | |||
| // { | |||
| // _manual_tracking = (bool)value; | |||
| // break; | |||
| // } | |||
| // case "_self_saveable_object_factories": | |||
| // { | |||
| // _self_saveable_object_factories = (IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>)value; | |||
| // break; | |||
| // } | |||
| // case "_self_update_uid": | |||
| // { | |||
| // _self_update_uid = (int)value; | |||
| // break; | |||
| // } | |||
| // case "_unconditional_checkpoint_dependencies": | |||
| // { | |||
| // _unconditional_checkpoint_dependencies = (IList<TrackableReference>)value; | |||
| // break; | |||
| // } | |||
| // case "_unconditional_deferred_dependencies": | |||
| // { | |||
| // _unconditional_deferred_dependencies = (Dictionary<string, IList<CheckpointPosition>>)value; | |||
| // break; | |||
| // } | |||
| // case "_unconditional_dependency_names": | |||
| // { | |||
| // _unconditional_dependency_names = (IDictionary<string, Trackable>)value; | |||
| // break; | |||
| // } | |||
| // case "SelfSaveableObjectFactories": | |||
| // { | |||
| // SelfSaveableObjectFactories = (IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>)value; | |||
| // break; | |||
| // } | |||
| // case "UpdateUid": | |||
| // { | |||
| // UpdateUid = (int)value; | |||
| // break; | |||
| // } | |||
| // default: | |||
| // { | |||
| // CustomizedAttributes[name] = value; | |||
| // break; | |||
| // } | |||
| // } | |||
| } | |||
| /// <summary> | |||
| /// Restore-on-create for a variable be saved with this `Checkpointable`. | |||
| @@ -279,7 +345,7 @@ namespace Tensorflow.Train | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| /// <exception cref="NotImplementedException"></exception> | |||
| public virtual IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
| public virtual IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> serialize_to_tensors() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| @@ -2,6 +2,8 @@ | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.IO.Compression; | |||
| using System.Linq; | |||
| using System.Linq.Expressions; | |||
| @@ -25,6 +27,48 @@ namespace Tensorflow.Training | |||
| } | |||
| } | |||
| static class TrackableWrapperUtils | |||
| { | |||
| internal static bool ShouldLoad(ITrackableWrapper wrapper, SavedUserObject proto) | |||
| { | |||
| if (proto.Identifier != wrapper.Identifier) | |||
| { | |||
| return false; | |||
| } | |||
| if (wrapper.Version < proto.Version.MinConsumer) | |||
| { | |||
| return false; | |||
| } | |||
| if (proto.Version.Producer < wrapper.MinProducerVersion) | |||
| { | |||
| return false; | |||
| } | |||
| foreach (var bad_version in proto.Version.BadConsumers) | |||
| { | |||
| if (bad_version == wrapper.Version) | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| internal static bool is_function(Trackable x) | |||
| { | |||
| return x is Function or ConcreteFunction; | |||
| } | |||
| } | |||
| public interface ITrackableWrapper | |||
| { | |||
| void SetValue(object name, object value); | |||
| String Identifier { get; } | |||
| int Version { get; } | |||
| int MinConsumerVersion { get; } | |||
| int MinProducerVersion { get; } | |||
| Trackable FromProto(SavedUserObject proto); | |||
| } | |||
| public abstract class TrackableDataStructure : Trackable | |||
| { | |||
| private bool _self_trainable; | |||
| @@ -36,7 +80,7 @@ namespace Tensorflow.Training | |||
| _self_extra_variables = new List<IVariableV1>(); | |||
| } | |||
| public abstract IEnumerable<Trackable> Values { get; } | |||
| public abstract ICollection<Trackable> Values { get; } | |||
| public bool Trainable { get => _self_trainable; set => _self_trainable = value; } | |||
| public IEnumerable<ILayer> Layers | |||
| { | |||
| @@ -134,7 +178,7 @@ namespace Tensorflow.Training | |||
| /// <param name="name"></param> | |||
| protected virtual Trackable _track_value(Trackable value, string name) | |||
| { | |||
| value = sticky_attribute_assignment(this, name, value); | |||
| value = (Trackable)sticky_attribute_assignment(this, name, value); | |||
| if(value is IVariableV1) | |||
| { | |||
| _self_extra_variables.Add(value as IVariableV1); | |||
| @@ -148,44 +192,273 @@ namespace Tensorflow.Training | |||
| return value.Value; | |||
| } | |||
| public static Trackable wrap_or_unwrap(Trackable value) | |||
| public static object wrap_or_unwrap(object value) | |||
| { | |||
| if(value is NoDependency dependency) | |||
| { | |||
| return dependency.Value; | |||
| } | |||
| if(value is Trackable trackable) | |||
| { | |||
| return trackable; | |||
| } | |||
| else if(value is IDictionary<object, Trackable> obj_dict) | |||
| { | |||
| return new DictWrapper(obj_dict); | |||
| } | |||
| else if(value is IList<Trackable> list) | |||
| { | |||
| return new ListWrapper(list); | |||
| } | |||
| else | |||
| { | |||
| return value; | |||
| } | |||
| } | |||
| public static object sticky_attribute_assignment(Trackable trackable, string name, object value) | |||
| { | |||
| bool add_dependency = value is not NoDependency; | |||
| value = wrap_or_unwrap(value); | |||
| if (!add_dependency) | |||
| { | |||
| return value; | |||
| } | |||
| if(value is Trackable trackable_obj) | |||
| { | |||
| trackable._track_trackable(trackable_obj, name, true); | |||
| } | |||
| return value; | |||
| } | |||
| } | |||
| // TODO(Rinne): Add Dict wrapper and Tuple wrapper | |||
| public class DictWrapper : TrackableDataStructure, IDictionary<object, Trackable>, ICloneable, ITrackableWrapper | |||
| { | |||
| private IDictionary<object, Trackable> _storage; | |||
| private bool _non_string_key; | |||
| private bool _external_modification; | |||
| private IDictionary<object, Trackable> _last_wrapped_dict_snapshot; | |||
| public DictWrapper(IDictionary<object, Trackable> wrapped_dict = null) | |||
| { | |||
| if(wrapped_dict is not null) | |||
| { | |||
| _storage = new Dictionary<object, Trackable>(wrapped_dict); | |||
| } | |||
| else | |||
| { | |||
| _storage = new Dictionary<object, Trackable>(); | |||
| } | |||
| _update_snapshot(); | |||
| } | |||
| public static Trackable wrap_or_unwrap(IList<Trackable> value) | |||
| public void SetValue(object name, object value) | |||
| { | |||
| return new ListWrapper(value); | |||
| Debug.Assert(value is Trackable); | |||
| this[name] = value as Trackable; | |||
| } | |||
| public String Identifier => "trackable_dict_wrapper"; | |||
| public int Version => 1; | |||
| public int MinConsumerVersion => 1; | |||
| public int MinProducerVersion => 1; | |||
| public Trackable FromProto(SavedUserObject proto) | |||
| { | |||
| return new DictWrapper(new Dictionary<object, Trackable>()); | |||
| } | |||
| public static Trackable wrap_or_unwrap(IEnumerable<Trackable> value) | |||
| public Trackable this[object key] | |||
| { | |||
| return new ListWrapper(value.ToList()); | |||
| get | |||
| { | |||
| return _storage[key]; | |||
| } | |||
| set | |||
| { | |||
| _check_self_external_modification(); | |||
| _maybe_initialize_trackable(); | |||
| bool no_dep = value is NoDependency; | |||
| if(key is string) | |||
| { | |||
| value = _track_value(value, key); | |||
| } | |||
| else | |||
| { | |||
| value = (Trackable)wrap_or_unwrap(value); | |||
| if(!no_dep && value is Trackable) | |||
| { | |||
| _non_string_key = true; | |||
| } | |||
| } | |||
| _storage[key] = value; | |||
| _update_snapshot(); | |||
| } | |||
| } | |||
| protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) | |||
| public ICollection<object> Keys => _storage.Keys; | |||
| public override ICollection<Trackable> Values => _storage.OrderBy(x => x.Key).Select(x => x.Value).ToArray(); | |||
| public void Add(object key, Trackable value) | |||
| { | |||
| value = wrap_or_unwrap(value); | |||
| trackable._track_trackable(value, name, true); | |||
| return value; | |||
| _storage[key] = value; | |||
| } | |||
| protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, NoDependency value) | |||
| public bool ContainsKey(object key) | |||
| { | |||
| var wrapped_value = wrap_or_unwrap(value); | |||
| trackable._track_trackable(wrapped_value, name, true); | |||
| return wrapped_value; | |||
| return _storage.ContainsKey(key); | |||
| } | |||
| protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, IList<Trackable> value) | |||
| public bool Remove(object key) | |||
| { | |||
| var wrapped_value = wrap_or_unwrap(value); | |||
| trackable._track_trackable(wrapped_value, name, true); | |||
| return wrapped_value; | |||
| _check_self_external_modification(); | |||
| var res = _storage.Remove(key); | |||
| _update_snapshot(); | |||
| return res; | |||
| } | |||
| } | |||
| public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable | |||
| public bool TryGetValue(object key, out Trackable value) | |||
| { | |||
| return _storage.TryGetValue(key, out value); | |||
| } | |||
| public int Count => _storage.Count; | |||
| public bool IsReadOnly => _storage.IsReadOnly; | |||
| public void Add(KeyValuePair<object, Trackable> item) | |||
| { | |||
| Add(item.Key, item.Value); | |||
| } | |||
| public void Clear() | |||
| { | |||
| _storage.Clear(); | |||
| _update_snapshot(); | |||
| } | |||
| public bool Contains(KeyValuePair<object, Trackable> item) | |||
| { | |||
| return _storage.Contains(item); | |||
| } | |||
| public void CopyTo(KeyValuePair<object, Trackable>[] array, int arrayIndex) | |||
| { | |||
| _storage.CopyTo(array, arrayIndex); | |||
| } | |||
| public bool Remove(KeyValuePair<object, Trackable> item) | |||
| { | |||
| _check_self_external_modification(); | |||
| var res = Remove(item); | |||
| _update_snapshot(); | |||
| return res; | |||
| } | |||
| public IEnumerator<KeyValuePair<object, Trackable>> GetEnumerator() | |||
| { | |||
| return _storage.GetEnumerator(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator(); | |||
| public object Clone() | |||
| { | |||
| var copied = new DictWrapper(_storage); | |||
| copied._external_modification = _external_modification; | |||
| copied._non_string_key = _non_string_key; | |||
| return copied; | |||
| } | |||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
| { | |||
| _check_self_external_modification(); | |||
| if (_non_string_key) | |||
| { | |||
| throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed \"" + | |||
| $"automatically on attribute assignment). The wrapped dictionary " + | |||
| $"contains a non-string key which maps to a trackable object or " + | |||
| $"mutable data structure.\n\nIf you don't need this dictionary " + | |||
| $"checkpointed, wrap it in a non-trackable " + | |||
| $"object; it will be subsequently ignored."); | |||
| } | |||
| if (_external_modification) | |||
| { | |||
| throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed " + | |||
| $"automatically on attribute assignment). The wrapped dictionary was " + | |||
| $"modified outside the wrapper (its final value was {this}, its value" + | |||
| $" when a checkpoint dependency was added was " + | |||
| $"{this._last_wrapped_dict_snapshot}), which breaks " + | |||
| $"restoration on object creation.\n\nIf you don't need this " + | |||
| $"dictionary checkpointed, wrap it in a " + | |||
| $"non-trackable object; it will be subsequently ignored."); | |||
| } | |||
| Debug.Assert(!Dirty); | |||
| var children = base._trackable_children(save_type, cache); | |||
| if(save_type == SaveType.SAVEDMODEL) | |||
| { | |||
| foreach(var item in _storage) | |||
| { | |||
| var key = item.Key; | |||
| var value = item.Value; | |||
| if (TrackableWrapperUtils.is_function(value)) | |||
| { | |||
| Debug.Assert(key is string); | |||
| children[key as string] = value; | |||
| } | |||
| } | |||
| } | |||
| return children; | |||
| } | |||
| protected Trackable _track_value(Trackable value, object name) | |||
| { | |||
| bool string_key = name is string; | |||
| if (!string_key) | |||
| { | |||
| name = "-non_string_key"; | |||
| } | |||
| try | |||
| { | |||
| bool no_dependency = value is NoDependency; | |||
| value = base._track_value(value, name as string); | |||
| if(!(string_key || no_dependency)) | |||
| { | |||
| _non_string_key = true; | |||
| } | |||
| return value; | |||
| } | |||
| catch (ValueError) | |||
| { | |||
| return (Trackable)sticky_attribute_assignment(this, name as string, value); | |||
| } | |||
| } | |||
| private bool Dirty => _external_modification || _non_string_key; | |||
| private void _check_self_external_modification() | |||
| { | |||
| if (Dirty) | |||
| { | |||
| return; | |||
| } | |||
| if(!this._storage.SequenceEqual(_last_wrapped_dict_snapshot)) | |||
| { | |||
| _external_modification = true; | |||
| _last_wrapped_dict_snapshot = null; | |||
| } | |||
| } | |||
| private void _update_snapshot() | |||
| { | |||
| // TODO(Rinne): deal with attribute_sentinel. | |||
| if (Dirty) return; | |||
| _last_wrapped_dict_snapshot = new Dictionary<object, Trackable>(_storage); | |||
| } | |||
| } | |||
| public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable, ITrackableWrapper | |||
| { | |||
| private IList<Trackable> _storage; | |||
| private bool _non_append_mutation_value; | |||
| @@ -198,11 +471,51 @@ namespace Tensorflow.Training | |||
| /// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save.</param> | |||
| public ListWrapper(IList<Trackable> wrapped_list) | |||
| { | |||
| _storage = wrapped_list; | |||
| _storage = new List<Trackable>(wrapped_list); | |||
| _non_append_mutation_value = _external_modification_value = false; | |||
| _last_wrapped_list_snapshot = new List<Trackable>(_storage); | |||
| } | |||
| public string Identifier => "trackable_list_wrapper"; | |||
| public int Version => 1; | |||
| public int MinConsumerVersion => 1; | |||
| public int MinProducerVersion => 1; | |||
| public Trackable FromProto(SavedUserObject proto) | |||
| { | |||
| if(TrackableWrapperUtils.ShouldLoad(this, proto)) | |||
| { | |||
| return new ListWrapper(new Trackable[] { }); | |||
| } | |||
| else | |||
| { | |||
| return null; | |||
| } | |||
| } | |||
| public void SetValue(object name, object value) | |||
| { | |||
| Debug.Assert(name is string); | |||
| if(int.TryParse(name as string, out var index)) | |||
| { | |||
| if(value is not Trackable trackable) | |||
| { | |||
| throw new TypeError("Cannot set an object which is not trackable to ListWrapper."); | |||
| } | |||
| if(Count <= index) | |||
| { | |||
| Add(trackable); | |||
| } | |||
| else | |||
| { | |||
| this[index] = trackable; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("Encounter an unexpected behavior in <ListWrapper.SetAttr>, please " + | |||
| "submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| } | |||
| protected bool NonAppendMuation { | |||
| get => _non_append_mutation_value; | |||
| set | |||
| @@ -222,7 +535,7 @@ namespace Tensorflow.Training | |||
| } | |||
| } | |||
| public override IEnumerable<Trackable> Values => this; | |||
| public override ICollection<Trackable> Values => this; | |||
| public bool IsReadOnly { get => _storage.IsReadOnly; } | |||
| /// <summary> | |||
| @@ -239,7 +552,7 @@ namespace Tensorflow.Training | |||
| private void update_snapshot() | |||
| { | |||
| // TODO: deal with `attribute_sentinel`. | |||
| // TODO(Rinne): deal with `attribute_sentinel`. | |||
| if (_external_modification_value || _non_append_mutation_value) return; | |||
| _last_wrapped_list_snapshot = new List<Trackable>(_storage); | |||
| } | |||
| @@ -286,9 +599,9 @@ namespace Tensorflow.Training | |||
| { | |||
| base._track_value(value, name); | |||
| } | |||
| catch(ValueError ex) | |||
| catch(ValueError) | |||
| { | |||
| value = sticky_attribute_assignment(this, name, value); | |||
| value = (Trackable)sticky_attribute_assignment(this, name, value); | |||
| } | |||
| return value; | |||
| } | |||
| @@ -343,7 +656,11 @@ namespace Tensorflow.Training | |||
| update_snapshot(); | |||
| } | |||
| public void Clear() => _storage.Clear(); | |||
| public void Clear() | |||
| { | |||
| _storage.Clear(); | |||
| update_snapshot(); | |||
| } | |||
| public bool Contains(Trackable item) => _storage.Contains(item); | |||
| @@ -519,6 +519,14 @@ namespace Tensorflow.Util | |||
| return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | |||
| } | |||
| public static T2 map_structure<T1, T2>(Func<T1, T2> func, T1 structure) where T2: class | |||
| { | |||
| var flat_structure = flatten(structure); | |||
| var mapped_flat_structure = flat_structure.Select(func).Select(x => (object)x); | |||
| return pack_sequence_as(structure, mapped_flat_structure) as T2; | |||
| } | |||
| /// <summary> | |||
| /// Same as map_structure, but with only one structure (no combining of multiple structures) | |||
| /// </summary> | |||
| @@ -97,7 +97,7 @@ namespace Tensorflow | |||
| else | |||
| { | |||
| unique_id = $"{handle_name}_{ops.uid()}"; | |||
| shared_name = tf.Context.shared_name(); | |||
| shared_name = null; | |||
| } | |||
| var attr = new AttrValue(); | |||
| @@ -60,7 +60,15 @@ namespace Tensorflow.Keras | |||
| public void track_variable(IVariableV1 v) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| return; | |||
| } | |||
| var graph = v.Graph; | |||
| if(graph is null) | |||
| { | |||
| graph = get_graph(); | |||
| } | |||
| _GRAPH_VARIABLES[graph.graph_key] = v; | |||
| } | |||
| @@ -21,10 +21,13 @@ using System.Linq; | |||
| using System.Threading; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Metrics; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Engine | |||
| @@ -349,5 +352,59 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| } | |||
| public override void SetAttr(string name, object value) | |||
| { | |||
| // TODO(Rinne): deal with "_self_setattr_tracking". | |||
| value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); | |||
| foreach(var val in nest.flatten(value)) | |||
| { | |||
| if(val is Metric) | |||
| { | |||
| // TODO(Rinne): deal with metrics. | |||
| } | |||
| } | |||
| // TODO(Rinne): deal with "_auto_track_sub_layers". | |||
| foreach(var val in nest.flatten(value)) | |||
| { | |||
| if(val is not IVariableV1 variable) | |||
| { | |||
| continue; | |||
| } | |||
| if (variable.Trainable) | |||
| { | |||
| if (_trainable_weights.Contains(variable)) | |||
| { | |||
| continue; | |||
| } | |||
| _trainable_weights.Add(variable); | |||
| } | |||
| else | |||
| { | |||
| if (_non_trainable_weights.Contains(variable)) | |||
| { | |||
| continue; | |||
| } | |||
| _non_trainable_weights.Add(variable); | |||
| } | |||
| keras.backend.track_variable(variable); | |||
| } | |||
| // Directly use the implementation of `Trackable`. | |||
| var t = this.GetType(); | |||
| var field_info = t.GetField(name); | |||
| if (field_info is not null) | |||
| { | |||
| field_info.SetValue(this, value); | |||
| } | |||
| else | |||
| { | |||
| CustomizedFields[name] = value; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -1,7 +1,12 @@ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using System.Diagnostics; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -22,14 +27,16 @@ namespace Tensorflow.Keras.Engine | |||
| IOptimizer optimizer; | |||
| IVariableV1 _steps_per_execution; | |||
| protected bool _is_graph_network; | |||
| protected Tensors inputs; | |||
| public Tensors inputs; | |||
| protected Tensors outputs; | |||
| protected List<string> input_names; | |||
| public string[] output_names; | |||
| IVariableV1 _train_counter; | |||
| IVariableV1 _test_counter; | |||
| IVariableV1 _predict_counter; | |||
| bool _base_model_initialized; | |||
| bool stop_training; | |||
| TensorSpec _saved_model_inputs_spec; | |||
| public bool IsGraphNetwork => _is_graph_network; | |||
| @@ -45,6 +52,38 @@ namespace Tensorflow.Keras.Engine | |||
| _init_batch_counters(); | |||
| } | |||
| public void _set_inputs(TensorSpec inputs) | |||
| { | |||
| _set_save_spec(inputs); | |||
| } | |||
| internal void _set_save_spec(TensorSpec inputs) | |||
| { | |||
| if(_saved_model_inputs_spec is not null) | |||
| { | |||
| return; | |||
| } | |||
| var input_names = this.input_names; | |||
| if(input_names is null || input_names.Count == 0) | |||
| { | |||
| input_names = compile_utils.create_pseudo_input_names(inputs); | |||
| } | |||
| var flat_inputs = nest.flatten(inputs); | |||
| List<TensorSpec> specs = new(); | |||
| foreach(var (name, tensor) in zip(input_names, flat_inputs)) | |||
| { | |||
| specs.Add(tf_utils.get_tensor_spec(tensor, dynamic_batch: false, name: name)); | |||
| } | |||
| var packed_specs = nest.pack_sequence_as(inputs, specs) as TensorSpec; | |||
| Debug.Assert(specs is not null); | |||
| _saved_model_inputs_spec = packed_specs; | |||
| if(this is Sequential && _buildInputShape is null) | |||
| { | |||
| _buildInputShape = nest.map_structure<TensorSpec, TensorShapeConfig>(x => x is null ? null : x.shape, packed_specs); | |||
| } | |||
| } | |||
| internal override void Initialize(LayerArgs args) | |||
| { | |||
| _init_batch_counters(); | |||
| @@ -145,6 +184,16 @@ namespace Tensorflow.Keras.Engine | |||
| return children; | |||
| } | |||
| public override void SetAttr(string name, object value) | |||
| { | |||
| // TODO(Rinne): deal with "_self_setattr_tracking". | |||
| //if(nest.flatten(value).All(v => v is Layer or IVariableV1 || base_layer_utils.has_weights(v))) | |||
| //{ | |||
| // this._base_model_initialized; | |||
| //} | |||
| base.SetAttr(name, value); | |||
| } | |||
| void IModel.set_stopTraining_true() | |||
| { | |||
| @@ -1,12 +1,14 @@ | |||
| using Newtonsoft.Json; | |||
| using Newtonsoft.Json.Linq; | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.ComponentModel; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Reflection; | |||
| using System.Text.RegularExpressions; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| @@ -17,6 +19,8 @@ using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using Tensorflow.Util; | |||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
| using static Tensorflow.ApiDef.Types; | |||
| using static Tensorflow.Binding; | |||
| @@ -190,12 +194,13 @@ namespace Tensorflow.Keras.Saving | |||
| Name = config["name"].ToObject<string>() | |||
| }); | |||
| //s.Name = config["name"].ToObject<string>(); | |||
| if(s.input is null || s.input.Length == 0) | |||
| if(s.inputs is null || s.inputs.Length == 0) | |||
| { | |||
| var first_layer = _get_child_layer_node_ids(model_id)[0]; | |||
| var input_specs = _infer_inputs(first_layer); | |||
| var input_shapes = _infer_inputs(first_layer, true); | |||
| var input_shapes = _infer_input_shapes(first_layer); | |||
| // `model._set_inputs(input_specs)` | |||
| s._set_inputs(input_specs); | |||
| // skip the check of input_specs is Dictionary | |||
| if (!s.Built) | |||
| @@ -220,12 +225,12 @@ namespace Tensorflow.Keras.Saving | |||
| private void _set_network_attributes_from_metadata(Model revived_object) | |||
| { | |||
| var metadata = revived_object.SerializedAttributes["matadata"] as JObject; | |||
| if (metadata.ContainsKey("dtype")) | |||
| var metadata = revived_object.SerializedAttributes["metadata"] as KerasMetaData; | |||
| if (metadata.DType != TF_DataType.DtInvalid) | |||
| { | |||
| // TODO(Rinne): set_dtype_policy. | |||
| } | |||
| revived_object.args.Trainable = metadata["trainable"].Value<bool>(); | |||
| revived_object.args.Trainable = metadata.Trainable; | |||
| } | |||
| /// <summary> | |||
| @@ -305,6 +310,11 @@ namespace Tensorflow.Keras.Saving | |||
| private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | |||
| { | |||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | |||
| // Debug(Rinne) | |||
| if(node_id == 11) | |||
| { | |||
| Console.WriteLine(); | |||
| } | |||
| if (loaded_nodes.ContainsKey(node_id)) | |||
| { | |||
| @@ -472,15 +482,7 @@ namespace Tensorflow.Keras.Saving | |||
| } | |||
| else | |||
| { | |||
| var properties = layer.GetType().GetProperties(); | |||
| foreach(var p in properties) | |||
| { | |||
| if(p.Name == name as string && p.GetValue(layer) is not null) | |||
| { | |||
| return; | |||
| } | |||
| } | |||
| Loader.setattr(layer, name, value); | |||
| layer.SetAttr(name as string, value); | |||
| } | |||
| } | |||
| @@ -607,7 +609,7 @@ namespace Tensorflow.Keras.Saving | |||
| if(build_input_shape is null) | |||
| { | |||
| build_input_shape = _infer_inputs(node_id, convert_to_shapes: true); | |||
| build_input_shape = _infer_input_shapes(node_id); | |||
| } | |||
| if(build_input_shape is not null) | |||
| @@ -633,7 +635,7 @@ namespace Tensorflow.Keras.Saving | |||
| /// <param name="layer_node_id"></param> | |||
| /// <param name="convert_to_shapes"></param> | |||
| /// <returns></returns> | |||
| private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false) | |||
| private TensorSpec _infer_inputs(int layer_node_id) | |||
| { | |||
| var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); | |||
| if(call_fn_id is null) | |||
| @@ -648,7 +650,22 @@ namespace Tensorflow.Keras.Saving | |||
| } | |||
| var call_fn_name = concrete_functions[0]; | |||
| var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; | |||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||
| var structured_input_signature = nested_structure_coder.decode_proto(call_fn_proto.CanonicalizedInputSignature); | |||
| Debug.Assert(structured_input_signature is IEnumerable); | |||
| var first_enumerator = (structured_input_signature as IEnumerable).GetEnumerator(); | |||
| first_enumerator.MoveNext(); | |||
| var first = first_enumerator.Current; | |||
| Debug.Assert(first is IEnumerable); | |||
| var inputs_enumerator = (first as IEnumerable).GetEnumerator(); | |||
| inputs_enumerator.MoveNext(); | |||
| var inputs = inputs_enumerator.Current as TensorSpec; | |||
| return inputs; | |||
| } | |||
| private Shape _infer_input_shapes(int layer_node_id) | |||
| { | |||
| var inputs = _infer_inputs(layer_node_id); | |||
| return nest.map_structure(x => x.shape, inputs); | |||
| } | |||
| private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child) | |||
| @@ -48,19 +48,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| } | |||
| else | |||
| { | |||
| var properties = layer.GetType().GetProperties(); | |||
| foreach (var p in properties) | |||
| { | |||
| if ((string)name == p.Name) | |||
| { | |||
| if(p.GetValue(layer) is not null) | |||
| { | |||
| return; | |||
| } | |||
| p.SetValue(layer, value); | |||
| return; | |||
| } | |||
| } | |||
| layer.SetAttr(name as string, value); | |||
| } | |||
| } | |||
| } | |||
| @@ -11,7 +11,7 @@ using Tensorflow.Keras.Optimizers; | |||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Training; | |||
| using System.Diagnostics; | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| @@ -135,12 +135,17 @@ public partial class KerasSavedModelUtils | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| })); | |||
| var layers = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||
| Dictionary<string, Trackable> res = new(); | |||
| res["variables"] = variables; | |||
| res["trainable_variables"] = trainable_variables; | |||
| res["non_trainable_variables"] = non_trainable_variables; | |||
| res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||
| Debug.Assert(variables is Trackable); | |||
| Debug.Assert(trainable_variables is Trackable); | |||
| Debug.Assert(non_trainable_variables is Trackable); | |||
| Debug.Assert(layers is Trackable); | |||
| res["variables"] = variables as Trackable; | |||
| res["trainable_variables"] = trainable_variables as Trackable; | |||
| res["non_trainable_variables"] = non_trainable_variables as Trackable; | |||
| res["layers"] = layers as Trackable; | |||
| return res; | |||
| } | |||
| @@ -165,6 +165,14 @@ namespace Tensorflow.Keras.Utils | |||
| } | |||
| } | |||
| public static bool has_weights(object obj) | |||
| { | |||
| var obj_type = obj.GetType(); | |||
| return obj_type.GetField("trainable_weights") is not null && | |||
| obj_type.GetField("non_trainable_weights") is not null && | |||
| obj is not Type; | |||
| } | |||
| // recusive | |||
| static bool uses_keras_history(Tensor op_input) | |||
| { | |||
| @@ -0,0 +1,22 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow.Keras.Utils | |||
| { | |||
| internal static class compile_utils | |||
| { | |||
| public static List<string> create_pseudo_input_names(TensorSpec inputs) | |||
| { | |||
| return _create_pseudo_names(inputs, "input_"); | |||
| } | |||
| private static List<string> _create_pseudo_names(TensorSpec tensors, string prefix) | |||
| { | |||
| // TODO(Rinne): align with tensorflow | |||
| return new List<string>() { $"{prefix}1" }; | |||
| } | |||
| } | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Framework.Models; | |||
| namespace Tensorflow.Keras.Utils | |||
| { | |||
| @@ -69,5 +70,29 @@ namespace Tensorflow.Keras.Utils | |||
| false_fn: false_fn, | |||
| name: name); | |||
| } | |||
| public static TensorSpec get_tensor_spec(Tensor t, bool dynamic_batch = false, string name = null) | |||
| { | |||
| throw new NotImplementedException("The function is waited to be implemented in the future."); | |||
| } | |||
| public static TensorSpec get_tensor_spec(TensorSpec t, bool dynamic_batch = false, string name = null) | |||
| { | |||
| var spec = t; | |||
| if (!dynamic_batch) | |||
| { | |||
| return spec; | |||
| } | |||
| var dynamic_batch_spec = new TensorSpec(t.shape, t.dtype, t.name); | |||
| var shape = dynamic_batch_spec.shape; | |||
| if(shape.rank > 0) | |||
| { | |||
| var shape_list = shape.as_int_list(); | |||
| // TODO(Rinne): check if -1 is equivalent to None in python. | |||
| shape_list[0] = -1; | |||
| dynamic_batch_spec.shape = new Shape(shape_list); | |||
| } | |||
| return dynamic_batch_spec; | |||
| } | |||
| } | |||
| } | |||
| @@ -64,5 +64,8 @@ public class SequentialModelLoad | |||
| { | |||
| var model = tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func"); | |||
| model.summary(); | |||
| var x = tf.ones((2, 10)); | |||
| var y = model.Apply(x); | |||
| } | |||
| } | |||