| @@ -30,7 +30,7 @@ namespace Tensorflow.Checkpoint | |||||
| ); | ); | ||||
| public static class SaveUtil | public static class SaveUtil | ||||
| { | { | ||||
| public static (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
| public static (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
| serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | ||||
| { | { | ||||
| var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | ||||
| @@ -119,16 +119,16 @@ namespace Tensorflow.Checkpoint | |||||
| /// <param name="call_with_mapped_captures"></param> | /// <param name="call_with_mapped_captures"></param> | ||||
| /// <param name="cache"></param> | /// <param name="cache"></param> | ||||
| /// <param name="object_graph_proto"></param> | /// <param name="object_graph_proto"></param> | ||||
| private static IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||||
| private static IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||||
| bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | ||||
| { | { | ||||
| Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||||
| Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||||
| foreach(var td in tensor_trackables) | foreach(var td in tensor_trackables) | ||||
| { | { | ||||
| // TODO: deal with cache. | // TODO: deal with cache. | ||||
| var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | ||||
| Trackable trackable = null; | Trackable trackable = null; | ||||
| IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||||
| IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict; | |||||
| if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | ||||
| { | { | ||||
| (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | ||||
| @@ -150,12 +150,12 @@ namespace Tensorflow.Checkpoint | |||||
| return serialized_tensors; | return serialized_tensors; | ||||
| } | } | ||||
| private static IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||||
| private static IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||||
| { | { | ||||
| var trackable = trackable_data.object_to_save; | var trackable = trackable_data.object_to_save; | ||||
| // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | ||||
| IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||||
| IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> ret_tensor_dict; | |||||
| if (call_with_mapped_captures) | if (call_with_mapped_captures) | ||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| @@ -165,8 +165,7 @@ namespace Tensorflow.Checkpoint | |||||
| ret_tensor_dict = trackable.serialize_to_tensors(); | ret_tensor_dict = trackable.serialize_to_tensors(); | ||||
| } | } | ||||
| // TODO: deal with the type `SaveSpce` (currently it will never be it). | |||||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||||
| Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new(); | |||||
| foreach(var pair in ret_tensor_dict) | foreach(var pair in ret_tensor_dict) | ||||
| { | { | ||||
| var local_name = TrackableUtils.escape_local_name(pair.Key); | var local_name = TrackableUtils.escape_local_name(pair.Key); | ||||
| @@ -175,10 +174,12 @@ namespace Tensorflow.Checkpoint | |||||
| tensor_dict[checkpoint_key] = maybe_tensor; | tensor_dict[checkpoint_key] = maybe_tensor; | ||||
| if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>()) | |||||
| foreach(var key in maybe_tensor.Keys) | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||||
| if (maybe_tensor[key].IsTypeOrDeriveFrom<SaveSpec>()) | |||||
| { | |||||
| maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name; | |||||
| } | |||||
| } | } | ||||
| if(object_graph_proto is not null) | if(object_graph_proto is not null) | ||||
| @@ -202,7 +203,7 @@ namespace Tensorflow.Checkpoint | |||||
| /// <param name="call_with_mapped_captures"></param> | /// <param name="call_with_mapped_captures"></param> | ||||
| /// <param name="object_graph_proto"></param> | /// <param name="object_graph_proto"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| private static (Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||||
| private static (Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||||
| bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | ||||
| { | { | ||||
| Dictionary<Trackable, string> object_names = new(); | Dictionary<Trackable, string> object_names = new(); | ||||
| @@ -45,12 +45,12 @@ public class TrackableSaver | |||||
| _graph_view = graph_view; | _graph_view = graph_view; | ||||
| // TODO: cache when not executing eagerly. | // TODO: cache when not executing eagerly. | ||||
| // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, | |||||
| // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder` | |||||
| // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | ||||
| } | } | ||||
| private (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
| private (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
| gather_serialized_tensors(Tensor? object_graph_tensor = null) | gather_serialized_tensors(Tensor? object_graph_tensor = null) | ||||
| { | { | ||||
| var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | ||||
| @@ -69,9 +69,10 @@ public class TrackableSaver | |||||
| Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | ||||
| if (!serialized_tensors.ContainsKey(Trackable.None)) | if (!serialized_tensors.ContainsKey(Trackable.None)) | ||||
| { | { | ||||
| serialized_tensors[Trackable.None] = new Dictionary<string, OneOf.OneOf<Tensor, IDictionary<string, Tensor>>>(); | |||||
| serialized_tensors[Trackable.None] = new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>(); | |||||
| } | } | ||||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; | |||||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY].Add(saveable_object_util.NO_SLICE_SPEC_KEY, object_graph_tensor); | |||||
| return (serialized_tensors, feed_additions, registered_savers, graph_proto); | return (serialized_tensors, feed_additions, registered_savers, graph_proto); | ||||
| } | } | ||||
| @@ -387,6 +388,7 @@ public class CheckpointRestoreCoordinator | |||||
| /// </summary> | /// </summary> | ||||
| public List<Trackable> AllTrackables => _all_trackables; | public List<Trackable> AllTrackables => _all_trackables; | ||||
| public HashSet<int> MatchedProtoIds => _matched_proto_ids; | public HashSet<int> MatchedProtoIds => _matched_proto_ids; | ||||
| // TODO(Rinne): change to weak ref. | |||||
| public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; | public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; | ||||
| public int RestoreUid => _restore_uid; | public int RestoreUid => _restore_uid; | ||||
| public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; | public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; | ||||
| @@ -160,12 +160,12 @@ namespace Tensorflow.Checkpoint | |||||
| /// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | /// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | ||||
| /// <param name="registered_savers"></param> | /// <param name="registered_savers"></param> | ||||
| /// <param name="call_with_mapped_capture"></param> | /// <param name="call_with_mapped_capture"></param> | ||||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors, | |||||
| IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | ||||
| { | { | ||||
| _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | ||||
| _restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>(); | _restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>(); | ||||
| Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | |||||
| Dictionary<string, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> tensors_by_device= new(); | |||||
| foreach(var pair in serialized_tensors) | foreach(var pair in serialized_tensors) | ||||
| { | { | ||||
| @@ -191,16 +191,7 @@ namespace Tensorflow.Checkpoint | |||||
| foreach(var item in tensor_dict) | foreach(var item in tensor_dict) | ||||
| { | { | ||||
| var checkpoint_key = item.Key; | var checkpoint_key = item.Key; | ||||
| IDictionary<string, Tensor> spec_to_tensor; | |||||
| if(item.Value.TryPickT0(out var t, out var dic)) | |||||
| { | |||||
| spec_to_tensor = new Dictionary<string, Tensor>(); | |||||
| spec_to_tensor[""] = t; | |||||
| } | |||||
| else | |||||
| { | |||||
| spec_to_tensor = dic; | |||||
| } | |||||
| var spec_to_tensor = item.Value; | |||||
| foreach(var spec in spec_to_tensor) | foreach(var spec in spec_to_tensor) | ||||
| { | { | ||||
| @@ -216,11 +207,19 @@ namespace Tensorflow.Checkpoint | |||||
| _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | ||||
| // skip the process of device name because lack of API. | // skip the process of device name because lack of API. | ||||
| var host_device = tensor.Device; | |||||
| var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>()); | |||||
| string host_device; | |||||
| if (tensor.IsT0) | |||||
| { | |||||
| host_device = tensor.AsT0.Device; | |||||
| } | |||||
| else | |||||
| { | |||||
| host_device = tensor.AsT1.device; | |||||
| } | |||||
| var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>()); | |||||
| if (!internal_dict.ContainsKey(checkpoint_key)) | if (!internal_dict.ContainsKey(checkpoint_key)) | ||||
| { | { | ||||
| internal_dict[checkpoint_key] = new Dictionary<string, Tensor>(); | |||||
| internal_dict[checkpoint_key] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||||
| } | } | ||||
| internal_dict[checkpoint_key][slice_spec] = tensor; | internal_dict[checkpoint_key][slice_spec] = tensor; | ||||
| } | } | ||||
| @@ -425,7 +424,7 @@ namespace Tensorflow.Checkpoint | |||||
| public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | ||||
| { | { | ||||
| Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||||
| Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||||
| foreach (var saveable in saveables) | foreach (var saveable in saveables) | ||||
| { | { | ||||
| var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | ||||
| @@ -3,6 +3,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Security; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| @@ -50,7 +51,7 @@ public class CheckpointPosition | |||||
| { | { | ||||
| _checkpoint.AllTrackables.Add(trackable); | _checkpoint.AllTrackables.Add(trackable); | ||||
| _checkpoint.MatchedProtoIds.Add(_proto_id); | _checkpoint.MatchedProtoIds.Add(_proto_id); | ||||
| if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) | |||||
| if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment) && current_assignment is not null) | |||||
| { | { | ||||
| // skip the `logging.warning`. | // skip the `logging.warning`. | ||||
| return false; | return false; | ||||
| @@ -120,6 +120,11 @@ namespace Tensorflow.Contexts | |||||
| name : | name : | ||||
| "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | ||||
| public string anonymous_name() | |||||
| { | |||||
| return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | |||||
| } | |||||
| public void graph_mode(bool isFunc = false) | public void graph_mode(bool isFunc = false) | ||||
| => context_switches.Push(false, isFunc); | => context_switches.Push(false, isFunc); | ||||
| @@ -6,8 +6,11 @@ | |||||
| public class DenseSpec : TypeSpec | public class DenseSpec : TypeSpec | ||||
| { | { | ||||
| protected Shape _shape; | protected Shape _shape; | ||||
| public Shape shape => _shape; | |||||
| public Shape shape | |||||
| { | |||||
| get { return _shape; } | |||||
| set { _shape = value; } | |||||
| } | |||||
| protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
| public TF_DataType dtype => _dtype; | public TF_DataType dtype => _dtype; | ||||
| @@ -311,7 +311,7 @@ namespace Tensorflow | |||||
| /// <param name="types">const TF_DataType*</param> | /// <param name="types">const TF_DataType*</param> | ||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphSetOutputHandleShapesAndTypes(IntPtr graph, TF_Output output, | |||||
| public static extern void TF_GraphSetOutputHandleShapesAndTypes(SafeGraphHandle graph, TF_Output output, | |||||
| int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types, | int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types, | ||||
| SafeStatusHandle status); | SafeStatusHandle status); | ||||
| @@ -30,6 +30,18 @@ namespace Tensorflow.Operations | |||||
| } | } | ||||
| } | } | ||||
| public static HandleData create_handle_data(Shape shape, TF_DataType dtype) | |||||
| { | |||||
| HandleData handle_data = new(); | |||||
| handle_data.IsSet = true; | |||||
| handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||||
| { | |||||
| Shape = shape.as_proto(), | |||||
| Dtype = dtype.as_datatype_enum() | |||||
| }); | |||||
| return handle_data; | |||||
| } | |||||
| public static void set_handle_data(Tensor target_t, HandleData handle_data) | public static void set_handle_data(Tensor target_t, HandleData handle_data) | ||||
| { | { | ||||
| if(target_t is EagerTensor) | if(target_t is EagerTensor) | ||||
| @@ -37,7 +49,8 @@ namespace Tensorflow.Operations | |||||
| target_t.HandleData = handle_data; | target_t.HandleData = handle_data; | ||||
| return; | return; | ||||
| } | } | ||||
| c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); | |||||
| // TODO(Rinne): enable it. (currently the internal c api cannot be invoked.) | |||||
| //c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -21,6 +21,9 @@ using Tensorflow.Train; | |||||
| using Tensorflow.Training.Saving.SavedModel; | using Tensorflow.Training.Saving.SavedModel; | ||||
| using Tensorflow.Variables; | using Tensorflow.Variables; | ||||
| using static Tensorflow.CppShapeInferenceResult.Types; | using static Tensorflow.CppShapeInferenceResult.Types; | ||||
| using static Tensorflow.Binding; | |||||
| using Tensorflow.Operations; | |||||
| using System.Buffers; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -31,6 +34,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null) | public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null) | ||||
| { | { | ||||
| // TODO(Rinne): deal with `_handle_graph`. | |||||
| var value_tensor = ops.convert_to_tensor(value); | var value_tensor = ops.convert_to_tensor(value); | ||||
| return gen_resource_variable_ops.assign_variable_op(handle, | return gen_resource_variable_ops.assign_variable_op(handle, | ||||
| value_tensor, | value_tensor, | ||||
| @@ -78,6 +82,18 @@ namespace Tensorflow | |||||
| string shared_name, string name, bool graph_mode, Tensor initial_value = null) | string shared_name, string name, bool graph_mode, Tensor initial_value = null) | ||||
| { | { | ||||
| var container = ops.get_default_graph().Container; | var container = ops.get_default_graph().Container; | ||||
| if(container is null) | |||||
| { | |||||
| container = ""; | |||||
| } | |||||
| if (!graph_mode) | |||||
| { | |||||
| if(shared_name is not null) | |||||
| { | |||||
| throw new Exception("Using an explicit shared_name is not allowed when executing eagerly."); | |||||
| } | |||||
| shared_name = tf.Context.anonymous_name(); | |||||
| } | |||||
| var handle = gen_resource_variable_ops.var_handle_op(shape: shape, | var handle = gen_resource_variable_ops.var_handle_op(shape: shape, | ||||
| dtype: dtype, | dtype: dtype, | ||||
| shared_name: shared_name, | shared_name: shared_name, | ||||
| @@ -95,26 +111,20 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| // We do not want two distinct ResourceVariable objects for the same | |||||
| // underlying resource in the runtime. | |||||
| // When in eager mode, explicitly ensure so here. When in graph mode, it's | |||||
| // ensured by always generating different variable names. | |||||
| var exists = gen_resource_variable_ops.var_is_initialized_op(handle); | |||||
| // We create an assert Op instead of checking right away in order to be | |||||
| // compatible with ASYNC execution mode. Further, since not all devices | |||||
| // support string tensors, we encode the assertion string in the Op name | |||||
| /*gen_logging_ops.assert(gen_math_ops.logical_not(exists), | |||||
| new[] { exists }, | |||||
| name: "EagerVariableNameReuse");*/ | |||||
| var handle_data = new HandleData(); | |||||
| handle_data.IsSet = true; | |||||
| handle_data.ShapeAndType.Add(new HandleShapeAndType | |||||
| var handle_data = handle_data_util.create_handle_data(shape, dtype); | |||||
| if (initial_value is not null && initial_value.dtype == dtypes.variant) | |||||
| { | { | ||||
| Dtype = dtype.as_datatype_enum(), | |||||
| Shape = shape.as_proto() | |||||
| }); | |||||
| var extra_handle_data = get_eager_safe_handle_data(initial_value); | |||||
| if (extra_handle_data is not null && extra_handle_data.IsSet) | |||||
| { | |||||
| if (!handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||||
| { | |||||
| throw new RuntimeError($"Expected VarHandleOp to return a length==1 shape_and_type, " + | |||||
| $"but saw: '{handle_data}'"); | |||||
| } | |||||
| handle_data.ShapeAndType.AddRange(extra_handle_data.ShapeAndType); | |||||
| } | |||||
| } | |||||
| _set_handle_shapes_and_types(handle, handle_data, graph_mode); | _set_handle_shapes_and_types(handle, handle_data, graph_mode); | ||||
| return handle; | return handle; | ||||
| } | } | ||||
| @@ -126,24 +136,48 @@ namespace Tensorflow | |||||
| /// <param name="handle"></param> | /// <param name="handle"></param> | ||||
| /// <param name="handle_data"></param> | /// <param name="handle_data"></param> | ||||
| /// <param name="graph_mode"></param> | /// <param name="graph_mode"></param> | ||||
| internal static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||||
| internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||||
| { | { | ||||
| tensor.HandleData = handle_data; | |||||
| if (!graph_mode) | if (!graph_mode) | ||||
| return; | return; | ||||
| var size = handle_data.ShapeAndType.Count; | |||||
| //var shapes = handle_data.ShapeAndType.Select(x => x.Shape); | |||||
| //var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray(); | |||||
| //var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray(); | |||||
| //var converted_shapes = shapes.Select<TensorShapeProto, Memory<int>>(s => | |||||
| //{ | |||||
| // if (!s.UnknownRank) | |||||
| // { | |||||
| // return s.Dim.Select(d => (int)d.Size).ToArray(); | |||||
| // } | |||||
| // else | |||||
| // { | |||||
| // return Memory<int>.Empty; | |||||
| // } | |||||
| //}).ToArray(); | |||||
| var shapes = new IntPtr[size]; | |||||
| var types = new DataType[size]; | |||||
| var ranks = new int[size]; | |||||
| //List<MemoryHandle> handles = new(); | |||||
| //IntPtr[] shapes_with_ptr = new IntPtr[converted_shapes.Length]; | |||||
| //foreach(var (i, m) in enumerate(converted_shapes)) | |||||
| //{ | |||||
| // if(m.IsEmpty) | |||||
| // { | |||||
| // shapes_with_ptr[i] = IntPtr.Zero; | |||||
| // } | |||||
| // else | |||||
| // { | |||||
| // var handle = m.Pin(); | |||||
| // handles.Add(handle); | |||||
| // shapes_with_ptr[i] = new IntPtr(handle.Pointer); | |||||
| // } | |||||
| //} | |||||
| for (int i = 0; i < size; i++) | |||||
| { | |||||
| var shapeAndType = handle_data.ShapeAndType[i]; | |||||
| types[i] = shapeAndType.Dtype; | |||||
| ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; | |||||
| var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); | |||||
| } | |||||
| //Status status = new(); | |||||
| //// TODO(Rinne): enable it. | |||||
| //c_api.TF_GraphSetOutputHandleShapesAndTypes(tensor.op.graph.c_graph, tensor._as_tf_output(), | |||||
| // shapes_with_ptr.Length, shapes_with_ptr, ranks, types, status); | |||||
| //handles = null; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -330,7 +330,7 @@ namespace Tensorflow { | |||||
| private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | ||||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | ||||
| private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_dependencies_codec | private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_dependencies_codec | ||||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||||
| = pb::FieldCodec.ForMessage(122, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | ||||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -698,9 +698,13 @@ namespace Tensorflow { | |||||
| break; | break; | ||||
| case 10: { | case 10: { | ||||
| children_.AddEntriesFrom(input, _repeated_children_codec); | children_.AddEntriesFrom(input, _repeated_children_codec); | ||||
| dependencies_.AddRange(children_.Except(dependencies_)); | |||||
| break; | break; | ||||
| } | } | ||||
| case 122: | |||||
| { | |||||
| dependencies_.AddEntriesFrom(input, _repeated_dependencies_codec); | |||||
| break; | |||||
| } | |||||
| case 26: { | case 26: { | ||||
| slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); | slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); | ||||
| break; | break; | ||||
| @@ -3,6 +3,7 @@ using System.Linq; | |||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| using Tensorflow.Keras.Saving.SavedModel; | using Tensorflow.Keras.Saving.SavedModel; | ||||
| using Tensorflow.Operations.Activation; | using Tensorflow.Operations.Activation; | ||||
| using Tensorflow.Training; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Train | namespace Tensorflow.Train | ||||
| @@ -25,6 +26,13 @@ namespace Tensorflow.Train | |||||
| } | } | ||||
| } | } | ||||
| public override void SetAttr(string name, object value) | |||||
| { | |||||
| // TODO(Rinne): deal with `self_setattr_tracking`. | |||||
| value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); | |||||
| base.SetAttr(name, value); | |||||
| } | |||||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | ||||
| { | { | ||||
| if(save_type != SaveType.SAVEDMODEL) | if(save_type != SaveType.SAVEDMODEL) | ||||
| @@ -34,6 +42,7 @@ namespace Tensorflow.Train | |||||
| Dictionary<string, Trackable> functions = new(); | Dictionary<string, Trackable> functions = new(); | ||||
| // TODO: process of logs. | // TODO: process of logs. | ||||
| // TODO(Rinne): deal with members. | |||||
| var properties = this.GetType().GetProperties(); | var properties = this.GetType().GetProperties(); | ||||
| foreach ( var property in properties ) | foreach ( var property in properties ) | ||||
| { | { | ||||
| @@ -45,6 +54,16 @@ namespace Tensorflow.Train | |||||
| } | } | ||||
| } | } | ||||
| foreach(var item in CustomizedFields) | |||||
| { | |||||
| var name = item.Key; | |||||
| var value = item.Value; | |||||
| if (value is Function or ConcreteFunction) | |||||
| { | |||||
| functions[name] = (Trackable)value; | |||||
| } | |||||
| } | |||||
| // TODO: process the type `core_types.GenericFunction`. | // TODO: process the type `core_types.GenericFunction`. | ||||
| Dictionary<string, Trackable> children = new(); | Dictionary<string, Trackable> children = new(); | ||||
| @@ -42,22 +42,25 @@ namespace Tensorflow | |||||
| _var_device = var.Device; | _var_device = var.Device; | ||||
| _var_shape = var.shape; | _var_shape = var.shape; | ||||
| Tensor _read_variable_closure(BaseResourceVariable v) | |||||
| Func<Tensor> _read_variable_closure(BaseResourceVariable v) | |||||
| { | { | ||||
| tf.device(v.Device); | |||||
| if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||||
| return () => | |||||
| { | { | ||||
| return null; | |||||
| } | |||||
| var x = v.read_value_no_copy(); | |||||
| tf.device("/device:CPU:0"); | |||||
| return array_ops.identity(x); | |||||
| tf.device(v.Device); | |||||
| if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| var x = v.read_value_no_copy(); | |||||
| tf.device("/device:CPU:0"); | |||||
| return array_ops.identity(x); | |||||
| }; | |||||
| } | } | ||||
| this.handle_op = var.Handle; | this.handle_op = var.Handle; | ||||
| var tensor = _read_variable_closure(var); | |||||
| var tensor_creator = _read_variable_closure(var); | |||||
| var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype); | |||||
| var spec = new SaveSpec(tensor_creator, slice_spec, name, dtype: var.dtype, device: var.Device); | |||||
| _op = var; | _op = var; | ||||
| specs = new SaveSpec[] { spec }; | specs = new SaveSpec[] { spec }; | ||||
| this.name = name; | this.name = name; | ||||
| @@ -66,6 +69,7 @@ namespace Tensorflow | |||||
| public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | ||||
| { | { | ||||
| var restored_tensor = restored_tensors[0]; | var restored_tensor = restored_tensors[0]; | ||||
| tf.device(_var_device); | |||||
| restored_tensor = array_ops.identity(restored_tensor); | restored_tensor = array_ops.identity(restored_tensor); | ||||
| return resource_variable_ops.shape_safe_assign_variable_handle( | return resource_variable_ops.shape_safe_assign_variable_handle( | ||||
| handle_op, _var_shape, restored_tensor); | handle_op, _var_shape, restored_tensor); | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Tensorflow.Exceptions; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| @@ -21,8 +23,24 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class SaveSpec | public class SaveSpec | ||||
| { | { | ||||
| private Tensor _tensor; | |||||
| public Tensor tensor => _tensor; | |||||
| private Tensor _tensor = null; | |||||
| private Func<Tensor> _tensor_creator = null; | |||||
| public Tensor tensor | |||||
| { | |||||
| get | |||||
| { | |||||
| if(_tensor is not null || _tensor_creator is null) | |||||
| { | |||||
| return _tensor; | |||||
| } | |||||
| else | |||||
| { | |||||
| return _tensor_creator(); | |||||
| } | |||||
| } | |||||
| } | |||||
| internal Func<Tensor> TensorCreator => _tensor_creator; | |||||
| private string _slice_spec; | private string _slice_spec; | ||||
| public string slice_spec => _slice_spec; | public string slice_spec => _slice_spec; | ||||
| @@ -32,13 +50,36 @@ namespace Tensorflow | |||||
| private TF_DataType _dtype; | private TF_DataType _dtype; | ||||
| public TF_DataType dtype => _dtype; | public TF_DataType dtype => _dtype; | ||||
| private string _device; | |||||
| public string device => _device; | |||||
| public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null) | |||||
| { | { | ||||
| _tensor = tensor; | _tensor = tensor; | ||||
| _slice_spec = slice_spec; | _slice_spec = slice_spec; | ||||
| _name = name; | _name = name; | ||||
| _dtype = dtype; | _dtype = dtype; | ||||
| if(device is not null) | |||||
| { | |||||
| _device = device; | |||||
| } | |||||
| else | |||||
| { | |||||
| _device = tensor.Device; | |||||
| } | |||||
| } | |||||
| public SaveSpec(Func<Tensor> tensor_creator, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null) | |||||
| { | |||||
| _tensor_creator = tensor_creator; | |||||
| _slice_spec = slice_spec; | |||||
| _name = name; | |||||
| if(dtype == TF_DataType.DtInvalid || device is null) | |||||
| { | |||||
| throw new AssertionError("When passing a callable `tensor` to a SaveSpec, an explicit dtype and device must be provided."); | |||||
| } | |||||
| _dtype = dtype; | |||||
| _device = device; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,20 @@ | |||||
| using System; | using System; | ||||
| using System.Diagnostics; | |||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Training; | |||||
| namespace Tensorflow; | namespace Tensorflow; | ||||
| public class RevivedTypes | public class RevivedTypes | ||||
| { | { | ||||
| private static Dictionary<string, ITrackableWrapper> _registered_revived_creator = new(); | |||||
| static RevivedTypes() | |||||
| { | |||||
| var list_wrapper = new ListWrapper(new Trackable[] { }); | |||||
| _registered_revived_creator[list_wrapper.Identifier] = list_wrapper; | |||||
| var dict_wrapper = new DictWrapper(new Dictionary<object, Trackable>()); | |||||
| _registered_revived_creator[dict_wrapper.Identifier] = dict_wrapper; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a SavedUserObject from a trackable object. | /// Create a SavedUserObject from a trackable object. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -12,13 +22,28 @@ public class RevivedTypes | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static SavedUserObject? serialize(Trackable obj) | public static SavedUserObject? serialize(Trackable obj) | ||||
| { | { | ||||
| // TODO: complete the implementation. | |||||
| // TODO(Rinne): complete the implementation. | |||||
| return null; | return null; | ||||
| } | } | ||||
| public static Tuple<Trackable, Action<object, object, object>> deserialize(object proto) | |||||
| public static (Trackable, Action<object, object, object>) deserialize(SavedUserObject proto) | |||||
| { | { | ||||
| // TODO: complete the implementation. | |||||
| return null; | |||||
| if(_registered_revived_creator.TryGetValue(proto.Identifier, out var wrapper)) | |||||
| { | |||||
| return (wrapper.FromProto(proto), (x, y, z) => | |||||
| { | |||||
| if (x is not ITrackableWrapper trackable) | |||||
| { | |||||
| throw new TypeError($"The type is expected to be `ITrackableWrapper`, but got {x.GetType()}."); | |||||
| } | |||||
| Debug.Assert(y is string); | |||||
| trackable.SetValue(y, z); | |||||
| } | |||||
| ); | |||||
| } | |||||
| else | |||||
| { | |||||
| return (null, null); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -49,6 +49,7 @@ namespace Tensorflow | |||||
| var temp = _proto.ToString(); | var temp = _proto.ToString(); | ||||
| _export_dir = export_dir; | _export_dir = export_dir; | ||||
| // TODO: `this._concrete_functions` and `this._restored_concrete_functions` | // TODO: `this._concrete_functions` and `this._restored_concrete_functions` | ||||
| // TODO(Rinne): This method is very slow, needs to be accelareted. | |||||
| _concrete_functions = function_deserialization.load_function_def_library( | _concrete_functions = function_deserialization.load_function_def_library( | ||||
| meta_graph.GraphDef.Library, _proto); | meta_graph.GraphDef.Library, _proto); | ||||
| _restored_concrete_functions = new HashSet<string>(); | _restored_concrete_functions = new HashSet<string>(); | ||||
| @@ -523,7 +524,7 @@ namespace Tensorflow | |||||
| continue; | continue; | ||||
| } | } | ||||
| setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); | setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); | ||||
| // skip the process of "__call__" | |||||
| // TODO(Rinne): deal with "__call__" | |||||
| } | } | ||||
| } | } | ||||
| @@ -595,13 +596,12 @@ namespace Tensorflow | |||||
| private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id) | private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id) | ||||
| { | { | ||||
| // skip the check of proto identifier because of lack of property. | // skip the check of proto identifier because of lack of property. | ||||
| var looked_up = RevivedTypes.deserialize(proto); | |||||
| if(looked_up is null) | |||||
| var (trackable, setter) = RevivedTypes.deserialize(proto); | |||||
| if(trackable is null) | |||||
| { | { | ||||
| return _recreate_base_user_object(proto, node_id); | return _recreate_base_user_object(proto, node_id); | ||||
| } | } | ||||
| return (looked_up.Item1, looked_up.Item2); | |||||
| return (trackable, setter); | |||||
| } | } | ||||
| private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null) | private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null) | ||||
| @@ -668,13 +668,20 @@ namespace Tensorflow | |||||
| public static Action<object, object, object> setattr = (x, y, z) => | public static Action<object, object, object> setattr = (x, y, z) => | ||||
| { | { | ||||
| Debug.Assert(y is string); | Debug.Assert(y is string); | ||||
| var properties = x.GetType().GetProperties(); | |||||
| foreach(var p in properties) | |||||
| if(x is Trackable trackable) | |||||
| { | |||||
| trackable.SetAttr(y as string, z); | |||||
| } | |||||
| else | |||||
| { | { | ||||
| if((string)y == p.Name) | |||||
| var properties = x.GetType().GetProperties(); | |||||
| foreach (var p in properties) | |||||
| { | { | ||||
| p.SetValue(x, z); | |||||
| return; | |||||
| if ((string)y == p.Name) | |||||
| { | |||||
| p.SetValue(x, z); | |||||
| return; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| // TODO(Rinne): check if the property has been set successfully. | // TODO(Rinne): check if the property has been set successfully. | ||||
| @@ -50,6 +50,10 @@ namespace Tensorflow | |||||
| } | } | ||||
| public static class saveable_object_util | public static class saveable_object_util | ||||
| { | { | ||||
| public static string NO_SLICE_SPEC_KEY = ""; | |||||
| private static HashSet<string> _VARIABLE_OPS = new HashSet<string>(new string[] { | |||||
| "Variable", "VariableV2", "AutoReloadVariable", "VarHandleOp", "ReadVariableOp" | |||||
| }); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the variables and names that will be used for a Saver. | /// Returns the variables and names that will be used for a Saver. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -123,19 +127,12 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static IEnumerable<MySaveableObject> saveable_objects_for_op(Tensor op, string name) | public static IEnumerable<MySaveableObject> saveable_objects_for_op(Tensor op, string name) | ||||
| { | { | ||||
| if (false) | |||||
| { | |||||
| } | |||||
| ops.init_scope(); | |||||
| var variable = ops.convert_to_tensor(op, as_ref: true); | |||||
| if (variable.dtype.is_ref_dtype()) | |||||
| yield return new ReferenceVariableSaveable(variable, "", name); | |||||
| else | else | ||||
| { | |||||
| ops.init_scope(); | |||||
| var variable = ops.convert_to_tensor(op, as_ref: true); | |||||
| if (variable.dtype.is_ref_dtype()) | |||||
| yield return new ReferenceVariableSaveable(variable, "", name); | |||||
| else | |||||
| yield return new ResourceVariableSaveable(variable, "", name); | |||||
| } | |||||
| yield return new ResourceVariableSaveable(variable, "", name); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -159,7 +156,7 @@ namespace Tensorflow | |||||
| yield return new ResourceVariableSaveable(variable, "", name); | yield return new ResourceVariableSaveable(variable, "", name); | ||||
| } | } | ||||
| } | } | ||||
| else | |||||
| else if(obj is not IVariableV1) | |||||
| { | { | ||||
| foreach(var pair in saveable_objects_from_trackable(obj)) | foreach(var pair in saveable_objects_from_trackable(obj)) | ||||
| { | { | ||||
| @@ -191,6 +188,30 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| else | |||||
| { | |||||
| // Variable | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| throw new ValueError($"Can only save/restore ResourceVariables when " + | |||||
| $"executing eagerly, got type: {obj.GetType()}."); | |||||
| } | |||||
| var variable = ops.convert_to_tensor(obj, as_ref: true); | |||||
| if (!_tensor_comes_from_variable(variable)) | |||||
| { | |||||
| throw new TypeError($"names_to_saveables must be a dict mapping string " + | |||||
| $"names to Tensors/Variables. Not a variable: {variable}"); | |||||
| } | |||||
| if(variable.op.type == "Variable" || variable.op.type == "VariableV2" || | |||||
| variable.op.type == "AutoReloadVariable") | |||||
| { | |||||
| yield return new ReferenceVariableSaveable(variable, "", name); | |||||
| } | |||||
| else | |||||
| { | |||||
| yield return new ResourceVariableSaveable(variable, "", name); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -267,24 +288,14 @@ namespace Tensorflow | |||||
| foreach (var pair in tensor_dict) | foreach (var pair in tensor_dict) | ||||
| { | { | ||||
| var tensor_name = pair.Key; | var tensor_name = pair.Key; | ||||
| var maybe_tensor = pair.Value; | |||||
| var internal_dict = pair.Value; | |||||
| local_names.Add(tensor_name); | local_names.Add(tensor_name); | ||||
| string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | ||||
| IDictionary<string, Tensor> internal_dict; | |||||
| if (maybe_tensor.TryPickT0(out var tensor, out var dic)) | |||||
| { | |||||
| internal_dict = new Dictionary<string, Tensor>(); | |||||
| internal_dict[""] = tensor; | |||||
| } | |||||
| else | |||||
| { | |||||
| internal_dict = dic; | |||||
| } | |||||
| foreach (var item in internal_dict) | foreach (var item in internal_dict) | ||||
| { | { | ||||
| specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | |||||
| Debug.Assert(item.Value.IsT0); | |||||
| specs.Add(new SaveSpec(item.Value.AsT0, item.Key, spec_name)); | |||||
| } | } | ||||
| } | } | ||||
| return new TrackableSaveable(obj, specs, name, local_names, prefix); | return new TrackableSaveable(obj, specs, name, local_names, prefix); | ||||
| @@ -316,9 +327,9 @@ namespace Tensorflow | |||||
| /// Converts a list of SaveableObjects to a tensor dictionary. | /// Converts a list of SaveableObjects to a tensor dictionary. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="saveables"></param> | /// <param name="saveables"></param> | ||||
| public static Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||||
| public static Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||||
| { | { | ||||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||||
| Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new(); | |||||
| foreach (var saveable in saveables) | foreach (var saveable in saveables) | ||||
| { | { | ||||
| foreach (var spec in saveable.specs) | foreach (var spec in saveable.specs) | ||||
| @@ -326,14 +337,11 @@ namespace Tensorflow | |||||
| // skip the check that if `spec` is callable. | // skip the check that if `spec` is callable. | ||||
| var name = convert_to_string(spec.name); | var name = convert_to_string(spec.name); | ||||
| var slice_spec = convert_to_string(spec.slice_spec); | var slice_spec = convert_to_string(spec.slice_spec); | ||||
| if (!string.IsNullOrEmpty(slice_spec)) | |||||
| { | |||||
| tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).AsT1[slice_spec] = spec.tensor; | |||||
| } | |||||
| else | |||||
| if (string.IsNullOrEmpty(slice_spec)) | |||||
| { | { | ||||
| tensor_dict[name] = spec.tensor; | |||||
| slice_spec = NO_SLICE_SPEC_KEY; | |||||
| } | } | ||||
| tensor_dict.SetDefault(name, new Dictionary<string, OneOf<Tensor, SaveSpec>>())[slice_spec] = spec.TensorCreator is null ? spec.tensor : spec; | |||||
| } | } | ||||
| } | } | ||||
| return tensor_dict; | return tensor_dict; | ||||
| @@ -397,6 +405,11 @@ namespace Tensorflow | |||||
| { | { | ||||
| return factory(key); | return factory(key); | ||||
| } | } | ||||
| private static bool _tensor_comes_from_variable(object v) | |||||
| { | |||||
| return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type); | |||||
| } | |||||
| } | } | ||||
| public class SaveableCompatibilityConverter: Trackable | public class SaveableCompatibilityConverter: Trackable | ||||
| @@ -412,7 +425,7 @@ namespace Tensorflow | |||||
| public object Obj => _obj; | public object Obj => _obj; | ||||
| public IList<MySaveableObject> mySaveables=> _saveables; | public IList<MySaveableObject> mySaveables=> _saveables; | ||||
| public override IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||||
| public override IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> serialize_to_tensors() | |||||
| { | { | ||||
| return saveable_object_util.saveable_object_to_tensor_dict(_saveables); | return saveable_object_util.saveable_object_to_tensor_dict(_saveables); | ||||
| } | } | ||||
| @@ -85,6 +85,72 @@ namespace Tensorflow.Train | |||||
| _self_saveable_object_factories = value; | _self_saveable_object_factories = value; | ||||
| } | } | ||||
| } | } | ||||
| public Dictionary<string, object> CustomizedFields { get; set; } = new Dictionary<string, object>(); | |||||
| public virtual void SetAttr(string name, object value) | |||||
| { | |||||
| var t = this.GetType(); | |||||
| var field_info = t.GetField(name); | |||||
| if(field_info is not null) | |||||
| { | |||||
| field_info.SetValue(this, value); | |||||
| } | |||||
| else | |||||
| { | |||||
| CustomizedFields[name] = value; | |||||
| } | |||||
| // On account of performance, we don't use reflection to set the attribute if it exists in `Trackable`. | |||||
| // When adding new members or properties to this class, please add corresponding process to this method. | |||||
| //switch (name) | |||||
| //{ | |||||
| // case "_manual_tracking": | |||||
| // { | |||||
| // _manual_tracking = (bool)value; | |||||
| // break; | |||||
| // } | |||||
| // case "_self_saveable_object_factories": | |||||
| // { | |||||
| // _self_saveable_object_factories = (IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>)value; | |||||
| // break; | |||||
| // } | |||||
| // case "_self_update_uid": | |||||
| // { | |||||
| // _self_update_uid = (int)value; | |||||
| // break; | |||||
| // } | |||||
| // case "_unconditional_checkpoint_dependencies": | |||||
| // { | |||||
| // _unconditional_checkpoint_dependencies = (IList<TrackableReference>)value; | |||||
| // break; | |||||
| // } | |||||
| // case "_unconditional_deferred_dependencies": | |||||
| // { | |||||
| // _unconditional_deferred_dependencies = (Dictionary<string, IList<CheckpointPosition>>)value; | |||||
| // break; | |||||
| // } | |||||
| // case "_unconditional_dependency_names": | |||||
| // { | |||||
| // _unconditional_dependency_names = (IDictionary<string, Trackable>)value; | |||||
| // break; | |||||
| // } | |||||
| // case "SelfSaveableObjectFactories": | |||||
| // { | |||||
| // SelfSaveableObjectFactories = (IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>)value; | |||||
| // break; | |||||
| // } | |||||
| // case "UpdateUid": | |||||
| // { | |||||
| // UpdateUid = (int)value; | |||||
| // break; | |||||
| // } | |||||
| // default: | |||||
| // { | |||||
| // CustomizedAttributes[name] = value; | |||||
| // break; | |||||
| // } | |||||
| // } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Restore-on-create for a variable be saved with this `Checkpointable`. | /// Restore-on-create for a variable be saved with this `Checkpointable`. | ||||
| @@ -279,7 +345,7 @@ namespace Tensorflow.Train | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| /// <exception cref="NotImplementedException"></exception> | /// <exception cref="NotImplementedException"></exception> | ||||
| public virtual IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||||
| public virtual IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> serialize_to_tensors() | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| @@ -2,6 +2,8 @@ | |||||
| using System; | using System; | ||||
| using System.Collections; | using System.Collections; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.IO.Compression; | using System.IO.Compression; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Linq.Expressions; | using System.Linq.Expressions; | ||||
| @@ -25,6 +27,48 @@ namespace Tensorflow.Training | |||||
| } | } | ||||
| } | } | ||||
| static class TrackableWrapperUtils | |||||
| { | |||||
| internal static bool ShouldLoad(ITrackableWrapper wrapper, SavedUserObject proto) | |||||
| { | |||||
| if (proto.Identifier != wrapper.Identifier) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| if (wrapper.Version < proto.Version.MinConsumer) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| if (proto.Version.Producer < wrapper.MinProducerVersion) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| foreach (var bad_version in proto.Version.BadConsumers) | |||||
| { | |||||
| if (bad_version == wrapper.Version) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| internal static bool is_function(Trackable x) | |||||
| { | |||||
| return x is Function or ConcreteFunction; | |||||
| } | |||||
| } | |||||
| public interface ITrackableWrapper | |||||
| { | |||||
| void SetValue(object name, object value); | |||||
| String Identifier { get; } | |||||
| int Version { get; } | |||||
| int MinConsumerVersion { get; } | |||||
| int MinProducerVersion { get; } | |||||
| Trackable FromProto(SavedUserObject proto); | |||||
| } | |||||
| public abstract class TrackableDataStructure : Trackable | public abstract class TrackableDataStructure : Trackable | ||||
| { | { | ||||
| private bool _self_trainable; | private bool _self_trainable; | ||||
| @@ -36,7 +80,7 @@ namespace Tensorflow.Training | |||||
| _self_extra_variables = new List<IVariableV1>(); | _self_extra_variables = new List<IVariableV1>(); | ||||
| } | } | ||||
| public abstract IEnumerable<Trackable> Values { get; } | |||||
| public abstract ICollection<Trackable> Values { get; } | |||||
| public bool Trainable { get => _self_trainable; set => _self_trainable = value; } | public bool Trainable { get => _self_trainable; set => _self_trainable = value; } | ||||
| public IEnumerable<ILayer> Layers | public IEnumerable<ILayer> Layers | ||||
| { | { | ||||
| @@ -134,7 +178,7 @@ namespace Tensorflow.Training | |||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| protected virtual Trackable _track_value(Trackable value, string name) | protected virtual Trackable _track_value(Trackable value, string name) | ||||
| { | { | ||||
| value = sticky_attribute_assignment(this, name, value); | |||||
| value = (Trackable)sticky_attribute_assignment(this, name, value); | |||||
| if(value is IVariableV1) | if(value is IVariableV1) | ||||
| { | { | ||||
| _self_extra_variables.Add(value as IVariableV1); | _self_extra_variables.Add(value as IVariableV1); | ||||
| @@ -148,44 +192,273 @@ namespace Tensorflow.Training | |||||
| return value.Value; | return value.Value; | ||||
| } | } | ||||
| public static Trackable wrap_or_unwrap(Trackable value) | |||||
| public static object wrap_or_unwrap(object value) | |||||
| { | { | ||||
| if(value is NoDependency dependency) | |||||
| { | |||||
| return dependency.Value; | |||||
| } | |||||
| if(value is Trackable trackable) | |||||
| { | |||||
| return trackable; | |||||
| } | |||||
| else if(value is IDictionary<object, Trackable> obj_dict) | |||||
| { | |||||
| return new DictWrapper(obj_dict); | |||||
| } | |||||
| else if(value is IList<Trackable> list) | |||||
| { | |||||
| return new ListWrapper(list); | |||||
| } | |||||
| else | |||||
| { | |||||
| return value; | |||||
| } | |||||
| } | |||||
| public static object sticky_attribute_assignment(Trackable trackable, string name, object value) | |||||
| { | |||||
| bool add_dependency = value is not NoDependency; | |||||
| value = wrap_or_unwrap(value); | |||||
| if (!add_dependency) | |||||
| { | |||||
| return value; | |||||
| } | |||||
| if(value is Trackable trackable_obj) | |||||
| { | |||||
| trackable._track_trackable(trackable_obj, name, true); | |||||
| } | |||||
| return value; | return value; | ||||
| } | } | ||||
| } | |||||
| // TODO(Rinne): Add Dict wrapper and Tuple wrapper | |||||
| public class DictWrapper : TrackableDataStructure, IDictionary<object, Trackable>, ICloneable, ITrackableWrapper | |||||
| { | |||||
| private IDictionary<object, Trackable> _storage; | |||||
| private bool _non_string_key; | |||||
| private bool _external_modification; | |||||
| private IDictionary<object, Trackable> _last_wrapped_dict_snapshot; | |||||
| public DictWrapper(IDictionary<object, Trackable> wrapped_dict = null) | |||||
| { | |||||
| if(wrapped_dict is not null) | |||||
| { | |||||
| _storage = new Dictionary<object, Trackable>(wrapped_dict); | |||||
| } | |||||
| else | |||||
| { | |||||
| _storage = new Dictionary<object, Trackable>(); | |||||
| } | |||||
| _update_snapshot(); | |||||
| } | |||||
| public static Trackable wrap_or_unwrap(IList<Trackable> value) | |||||
| public void SetValue(object name, object value) | |||||
| { | { | ||||
| return new ListWrapper(value); | |||||
| Debug.Assert(value is Trackable); | |||||
| this[name] = value as Trackable; | |||||
| } | |||||
| public String Identifier => "trackable_dict_wrapper"; | |||||
| public int Version => 1; | |||||
| public int MinConsumerVersion => 1; | |||||
| public int MinProducerVersion => 1; | |||||
| public Trackable FromProto(SavedUserObject proto) | |||||
| { | |||||
| return new DictWrapper(new Dictionary<object, Trackable>()); | |||||
| } | } | ||||
| public static Trackable wrap_or_unwrap(IEnumerable<Trackable> value) | |||||
| public Trackable this[object key] | |||||
| { | { | ||||
| return new ListWrapper(value.ToList()); | |||||
| get | |||||
| { | |||||
| return _storage[key]; | |||||
| } | |||||
| set | |||||
| { | |||||
| _check_self_external_modification(); | |||||
| _maybe_initialize_trackable(); | |||||
| bool no_dep = value is NoDependency; | |||||
| if(key is string) | |||||
| { | |||||
| value = _track_value(value, key); | |||||
| } | |||||
| else | |||||
| { | |||||
| value = (Trackable)wrap_or_unwrap(value); | |||||
| if(!no_dep && value is Trackable) | |||||
| { | |||||
| _non_string_key = true; | |||||
| } | |||||
| } | |||||
| _storage[key] = value; | |||||
| _update_snapshot(); | |||||
| } | |||||
| } | } | ||||
| protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) | |||||
| public ICollection<object> Keys => _storage.Keys; | |||||
| public override ICollection<Trackable> Values => _storage.OrderBy(x => x.Key).Select(x => x.Value).ToArray(); | |||||
| public void Add(object key, Trackable value) | |||||
| { | { | ||||
| value = wrap_or_unwrap(value); | |||||
| trackable._track_trackable(value, name, true); | |||||
| return value; | |||||
| _storage[key] = value; | |||||
| } | } | ||||
| protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, NoDependency value) | |||||
| public bool ContainsKey(object key) | |||||
| { | { | ||||
| var wrapped_value = wrap_or_unwrap(value); | |||||
| trackable._track_trackable(wrapped_value, name, true); | |||||
| return wrapped_value; | |||||
| return _storage.ContainsKey(key); | |||||
| } | } | ||||
| protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, IList<Trackable> value) | |||||
| public bool Remove(object key) | |||||
| { | { | ||||
| var wrapped_value = wrap_or_unwrap(value); | |||||
| trackable._track_trackable(wrapped_value, name, true); | |||||
| return wrapped_value; | |||||
| _check_self_external_modification(); | |||||
| var res = _storage.Remove(key); | |||||
| _update_snapshot(); | |||||
| return res; | |||||
| } | } | ||||
| } | |||||
| public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable | |||||
| public bool TryGetValue(object key, out Trackable value) | |||||
| { | |||||
| return _storage.TryGetValue(key, out value); | |||||
| } | |||||
| public int Count => _storage.Count; | |||||
| public bool IsReadOnly => _storage.IsReadOnly; | |||||
| public void Add(KeyValuePair<object, Trackable> item) | |||||
| { | |||||
| Add(item.Key, item.Value); | |||||
| } | |||||
| public void Clear() | |||||
| { | |||||
| _storage.Clear(); | |||||
| _update_snapshot(); | |||||
| } | |||||
| public bool Contains(KeyValuePair<object, Trackable> item) | |||||
| { | |||||
| return _storage.Contains(item); | |||||
| } | |||||
| public void CopyTo(KeyValuePair<object, Trackable>[] array, int arrayIndex) | |||||
| { | |||||
| _storage.CopyTo(array, arrayIndex); | |||||
| } | |||||
| public bool Remove(KeyValuePair<object, Trackable> item) | |||||
| { | |||||
| _check_self_external_modification(); | |||||
| var res = Remove(item); | |||||
| _update_snapshot(); | |||||
| return res; | |||||
| } | |||||
| public IEnumerator<KeyValuePair<object, Trackable>> GetEnumerator() | |||||
| { | |||||
| return _storage.GetEnumerator(); | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator(); | |||||
| public object Clone() | |||||
| { | |||||
| var copied = new DictWrapper(_storage); | |||||
| copied._external_modification = _external_modification; | |||||
| copied._non_string_key = _non_string_key; | |||||
| return copied; | |||||
| } | |||||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
| { | |||||
| _check_self_external_modification(); | |||||
| if (_non_string_key) | |||||
| { | |||||
| throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed \"" + | |||||
| $"automatically on attribute assignment). The wrapped dictionary " + | |||||
| $"contains a non-string key which maps to a trackable object or " + | |||||
| $"mutable data structure.\n\nIf you don't need this dictionary " + | |||||
| $"checkpointed, wrap it in a non-trackable " + | |||||
| $"object; it will be subsequently ignored."); | |||||
| } | |||||
| if (_external_modification) | |||||
| { | |||||
| throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed " + | |||||
| $"automatically on attribute assignment). The wrapped dictionary was " + | |||||
| $"modified outside the wrapper (its final value was {this}, its value" + | |||||
| $" when a checkpoint dependency was added was " + | |||||
| $"{this._last_wrapped_dict_snapshot}), which breaks " + | |||||
| $"restoration on object creation.\n\nIf you don't need this " + | |||||
| $"dictionary checkpointed, wrap it in a " + | |||||
| $"non-trackable object; it will be subsequently ignored."); | |||||
| } | |||||
| Debug.Assert(!Dirty); | |||||
| var children = base._trackable_children(save_type, cache); | |||||
| if(save_type == SaveType.SAVEDMODEL) | |||||
| { | |||||
| foreach(var item in _storage) | |||||
| { | |||||
| var key = item.Key; | |||||
| var value = item.Value; | |||||
| if (TrackableWrapperUtils.is_function(value)) | |||||
| { | |||||
| Debug.Assert(key is string); | |||||
| children[key as string] = value; | |||||
| } | |||||
| } | |||||
| } | |||||
| return children; | |||||
| } | |||||
| protected Trackable _track_value(Trackable value, object name) | |||||
| { | |||||
| bool string_key = name is string; | |||||
| if (!string_key) | |||||
| { | |||||
| name = "-non_string_key"; | |||||
| } | |||||
| try | |||||
| { | |||||
| bool no_dependency = value is NoDependency; | |||||
| value = base._track_value(value, name as string); | |||||
| if(!(string_key || no_dependency)) | |||||
| { | |||||
| _non_string_key = true; | |||||
| } | |||||
| return value; | |||||
| } | |||||
| catch (ValueError) | |||||
| { | |||||
| return (Trackable)sticky_attribute_assignment(this, name as string, value); | |||||
| } | |||||
| } | |||||
| private bool Dirty => _external_modification || _non_string_key; | |||||
| private void _check_self_external_modification() | |||||
| { | |||||
| if (Dirty) | |||||
| { | |||||
| return; | |||||
| } | |||||
| if(!this._storage.SequenceEqual(_last_wrapped_dict_snapshot)) | |||||
| { | |||||
| _external_modification = true; | |||||
| _last_wrapped_dict_snapshot = null; | |||||
| } | |||||
| } | |||||
| private void _update_snapshot() | |||||
| { | |||||
| // TODO(Rinne): deal with attribute_sentinel. | |||||
| if (Dirty) return; | |||||
| _last_wrapped_dict_snapshot = new Dictionary<object, Trackable>(_storage); | |||||
| } | |||||
| } | |||||
| public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable, ITrackableWrapper | |||||
| { | { | ||||
| private IList<Trackable> _storage; | private IList<Trackable> _storage; | ||||
| private bool _non_append_mutation_value; | private bool _non_append_mutation_value; | ||||
| @@ -198,11 +471,51 @@ namespace Tensorflow.Training | |||||
| /// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save.</param> | /// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save.</param> | ||||
| public ListWrapper(IList<Trackable> wrapped_list) | public ListWrapper(IList<Trackable> wrapped_list) | ||||
| { | { | ||||
| _storage = wrapped_list; | |||||
| _storage = new List<Trackable>(wrapped_list); | |||||
| _non_append_mutation_value = _external_modification_value = false; | _non_append_mutation_value = _external_modification_value = false; | ||||
| _last_wrapped_list_snapshot = new List<Trackable>(_storage); | _last_wrapped_list_snapshot = new List<Trackable>(_storage); | ||||
| } | } | ||||
| public string Identifier => "trackable_list_wrapper"; | |||||
| public int Version => 1; | |||||
| public int MinConsumerVersion => 1; | |||||
| public int MinProducerVersion => 1; | |||||
| public Trackable FromProto(SavedUserObject proto) | |||||
| { | |||||
| if(TrackableWrapperUtils.ShouldLoad(this, proto)) | |||||
| { | |||||
| return new ListWrapper(new Trackable[] { }); | |||||
| } | |||||
| else | |||||
| { | |||||
| return null; | |||||
| } | |||||
| } | |||||
| public void SetValue(object name, object value) | |||||
| { | |||||
| Debug.Assert(name is string); | |||||
| if(int.TryParse(name as string, out var index)) | |||||
| { | |||||
| if(value is not Trackable trackable) | |||||
| { | |||||
| throw new TypeError("Cannot set an object which is not trackable to ListWrapper."); | |||||
| } | |||||
| if(Count <= index) | |||||
| { | |||||
| Add(trackable); | |||||
| } | |||||
| else | |||||
| { | |||||
| this[index] = trackable; | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new NotImplementedException("Encounter an unexpected behavior in <ListWrapper.SetAttr>, please " + | |||||
| "submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
| } | |||||
| } | |||||
| protected bool NonAppendMuation { | protected bool NonAppendMuation { | ||||
| get => _non_append_mutation_value; | get => _non_append_mutation_value; | ||||
| set | set | ||||
| @@ -222,7 +535,7 @@ namespace Tensorflow.Training | |||||
| } | } | ||||
| } | } | ||||
| public override IEnumerable<Trackable> Values => this; | |||||
| public override ICollection<Trackable> Values => this; | |||||
| public bool IsReadOnly { get => _storage.IsReadOnly; } | public bool IsReadOnly { get => _storage.IsReadOnly; } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -239,7 +552,7 @@ namespace Tensorflow.Training | |||||
| private void update_snapshot() | private void update_snapshot() | ||||
| { | { | ||||
| // TODO: deal with `attribute_sentinel`. | |||||
| // TODO(Rinne): deal with `attribute_sentinel`. | |||||
| if (_external_modification_value || _non_append_mutation_value) return; | if (_external_modification_value || _non_append_mutation_value) return; | ||||
| _last_wrapped_list_snapshot = new List<Trackable>(_storage); | _last_wrapped_list_snapshot = new List<Trackable>(_storage); | ||||
| } | } | ||||
| @@ -286,9 +599,9 @@ namespace Tensorflow.Training | |||||
| { | { | ||||
| base._track_value(value, name); | base._track_value(value, name); | ||||
| } | } | ||||
| catch(ValueError ex) | |||||
| catch(ValueError) | |||||
| { | { | ||||
| value = sticky_attribute_assignment(this, name, value); | |||||
| value = (Trackable)sticky_attribute_assignment(this, name, value); | |||||
| } | } | ||||
| return value; | return value; | ||||
| } | } | ||||
| @@ -343,7 +656,11 @@ namespace Tensorflow.Training | |||||
| update_snapshot(); | update_snapshot(); | ||||
| } | } | ||||
| public void Clear() => _storage.Clear(); | |||||
| public void Clear() | |||||
| { | |||||
| _storage.Clear(); | |||||
| update_snapshot(); | |||||
| } | |||||
| public bool Contains(Trackable item) => _storage.Contains(item); | public bool Contains(Trackable item) => _storage.Contains(item); | ||||
| @@ -519,6 +519,14 @@ namespace Tensorflow.Util | |||||
| return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | ||||
| } | } | ||||
| public static T2 map_structure<T1, T2>(Func<T1, T2> func, T1 structure) where T2: class | |||||
| { | |||||
| var flat_structure = flatten(structure); | |||||
| var mapped_flat_structure = flat_structure.Select(func).Select(x => (object)x); | |||||
| return pack_sequence_as(structure, mapped_flat_structure) as T2; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Same as map_structure, but with only one structure (no combining of multiple structures) | /// Same as map_structure, but with only one structure (no combining of multiple structures) | ||||
| /// </summary> | /// </summary> | ||||
| @@ -97,7 +97,7 @@ namespace Tensorflow | |||||
| else | else | ||||
| { | { | ||||
| unique_id = $"{handle_name}_{ops.uid()}"; | unique_id = $"{handle_name}_{ops.uid()}"; | ||||
| shared_name = tf.Context.shared_name(); | |||||
| shared_name = null; | |||||
| } | } | ||||
| var attr = new AttrValue(); | var attr = new AttrValue(); | ||||
| @@ -60,7 +60,15 @@ namespace Tensorflow.Keras | |||||
| public void track_variable(IVariableV1 v) | public void track_variable(IVariableV1 v) | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| return; | |||||
| } | |||||
| var graph = v.Graph; | var graph = v.Graph; | ||||
| if(graph is null) | |||||
| { | |||||
| graph = get_graph(); | |||||
| } | |||||
| _GRAPH_VARIABLES[graph.graph_key] = v; | _GRAPH_VARIABLES[graph.graph_key] = v; | ||||
| } | } | ||||
| @@ -21,10 +21,13 @@ using System.Linq; | |||||
| using System.Threading; | using System.Threading; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Metrics; | |||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Training; | |||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| @@ -349,5 +352,59 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| } | } | ||||
| public override void SetAttr(string name, object value) | |||||
| { | |||||
| // TODO(Rinne): deal with "_self_setattr_tracking". | |||||
| value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); | |||||
| foreach(var val in nest.flatten(value)) | |||||
| { | |||||
| if(val is Metric) | |||||
| { | |||||
| // TODO(Rinne): deal with metrics. | |||||
| } | |||||
| } | |||||
| // TODO(Rinne): deal with "_auto_track_sub_layers". | |||||
| foreach(var val in nest.flatten(value)) | |||||
| { | |||||
| if(val is not IVariableV1 variable) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| if (variable.Trainable) | |||||
| { | |||||
| if (_trainable_weights.Contains(variable)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| _trainable_weights.Add(variable); | |||||
| } | |||||
| else | |||||
| { | |||||
| if (_non_trainable_weights.Contains(variable)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| _non_trainable_weights.Add(variable); | |||||
| } | |||||
| keras.backend.track_variable(variable); | |||||
| } | |||||
| // Directly use the implementation of `Trackable`. | |||||
| var t = this.GetType(); | |||||
| var field_info = t.GetField(name); | |||||
| if (field_info is not null) | |||||
| { | |||||
| field_info.SetValue(this, value); | |||||
| } | |||||
| else | |||||
| { | |||||
| CustomizedFields[name] = value; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,7 +1,12 @@ | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using System.Diagnostics; | |||||
| using Tensorflow.Framework.Models; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Keras.Saving.SavedModel; | using Tensorflow.Keras.Saving.SavedModel; | ||||
| using Tensorflow.Keras.Utils; | |||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| @@ -22,14 +27,16 @@ namespace Tensorflow.Keras.Engine | |||||
| IOptimizer optimizer; | IOptimizer optimizer; | ||||
| IVariableV1 _steps_per_execution; | IVariableV1 _steps_per_execution; | ||||
| protected bool _is_graph_network; | protected bool _is_graph_network; | ||||
| protected Tensors inputs; | |||||
| public Tensors inputs; | |||||
| protected Tensors outputs; | protected Tensors outputs; | ||||
| protected List<string> input_names; | |||||
| public string[] output_names; | public string[] output_names; | ||||
| IVariableV1 _train_counter; | IVariableV1 _train_counter; | ||||
| IVariableV1 _test_counter; | IVariableV1 _test_counter; | ||||
| IVariableV1 _predict_counter; | IVariableV1 _predict_counter; | ||||
| bool _base_model_initialized; | bool _base_model_initialized; | ||||
| bool stop_training; | bool stop_training; | ||||
| TensorSpec _saved_model_inputs_spec; | |||||
| public bool IsGraphNetwork => _is_graph_network; | public bool IsGraphNetwork => _is_graph_network; | ||||
| @@ -45,6 +52,38 @@ namespace Tensorflow.Keras.Engine | |||||
| _init_batch_counters(); | _init_batch_counters(); | ||||
| } | } | ||||
| public void _set_inputs(TensorSpec inputs) | |||||
| { | |||||
| _set_save_spec(inputs); | |||||
| } | |||||
| internal void _set_save_spec(TensorSpec inputs) | |||||
| { | |||||
| if(_saved_model_inputs_spec is not null) | |||||
| { | |||||
| return; | |||||
| } | |||||
| var input_names = this.input_names; | |||||
| if(input_names is null || input_names.Count == 0) | |||||
| { | |||||
| input_names = compile_utils.create_pseudo_input_names(inputs); | |||||
| } | |||||
| var flat_inputs = nest.flatten(inputs); | |||||
| List<TensorSpec> specs = new(); | |||||
| foreach(var (name, tensor) in zip(input_names, flat_inputs)) | |||||
| { | |||||
| specs.Add(tf_utils.get_tensor_spec(tensor, dynamic_batch: false, name: name)); | |||||
| } | |||||
| var packed_specs = nest.pack_sequence_as(inputs, specs) as TensorSpec; | |||||
| Debug.Assert(specs is not null); | |||||
| _saved_model_inputs_spec = packed_specs; | |||||
| if(this is Sequential && _buildInputShape is null) | |||||
| { | |||||
| _buildInputShape = nest.map_structure<TensorSpec, TensorShapeConfig>(x => x is null ? null : x.shape, packed_specs); | |||||
| } | |||||
| } | |||||
| internal override void Initialize(LayerArgs args) | internal override void Initialize(LayerArgs args) | ||||
| { | { | ||||
| _init_batch_counters(); | _init_batch_counters(); | ||||
| @@ -145,6 +184,16 @@ namespace Tensorflow.Keras.Engine | |||||
| return children; | return children; | ||||
| } | } | ||||
| public override void SetAttr(string name, object value) | |||||
| { | |||||
| // TODO(Rinne): deal with "_self_setattr_tracking". | |||||
| //if(nest.flatten(value).All(v => v is Layer or IVariableV1 || base_layer_utils.has_weights(v))) | |||||
| //{ | |||||
| // this._base_model_initialized; | |||||
| //} | |||||
| base.SetAttr(name, value); | |||||
| } | |||||
| void IModel.set_stopTraining_true() | void IModel.set_stopTraining_true() | ||||
| { | { | ||||
| @@ -1,12 +1,14 @@ | |||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| using Newtonsoft.Json.Linq; | using Newtonsoft.Json.Linq; | ||||
| using System; | using System; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.ComponentModel; | using System.ComponentModel; | ||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Reflection; | using System.Reflection; | ||||
| using System.Text.RegularExpressions; | using System.Text.RegularExpressions; | ||||
| using Tensorflow.Framework.Models; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
| @@ -17,6 +19,8 @@ using Tensorflow.Keras.Saving.SavedModel; | |||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| using Tensorflow.Training.Saving.SavedModel; | |||||
| using Tensorflow.Util; | |||||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | using ThirdParty.Tensorflow.Python.Keras.Protobuf; | ||||
| using static Tensorflow.ApiDef.Types; | using static Tensorflow.ApiDef.Types; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -190,12 +194,13 @@ namespace Tensorflow.Keras.Saving | |||||
| Name = config["name"].ToObject<string>() | Name = config["name"].ToObject<string>() | ||||
| }); | }); | ||||
| //s.Name = config["name"].ToObject<string>(); | //s.Name = config["name"].ToObject<string>(); | ||||
| if(s.input is null || s.input.Length == 0) | |||||
| if(s.inputs is null || s.inputs.Length == 0) | |||||
| { | { | ||||
| var first_layer = _get_child_layer_node_ids(model_id)[0]; | var first_layer = _get_child_layer_node_ids(model_id)[0]; | ||||
| var input_specs = _infer_inputs(first_layer); | var input_specs = _infer_inputs(first_layer); | ||||
| var input_shapes = _infer_inputs(first_layer, true); | |||||
| var input_shapes = _infer_input_shapes(first_layer); | |||||
| // `model._set_inputs(input_specs)` | // `model._set_inputs(input_specs)` | ||||
| s._set_inputs(input_specs); | |||||
| // skip the check of input_specs is Dictionary | // skip the check of input_specs is Dictionary | ||||
| if (!s.Built) | if (!s.Built) | ||||
| @@ -220,12 +225,12 @@ namespace Tensorflow.Keras.Saving | |||||
| private void _set_network_attributes_from_metadata(Model revived_object) | private void _set_network_attributes_from_metadata(Model revived_object) | ||||
| { | { | ||||
| var metadata = revived_object.SerializedAttributes["matadata"] as JObject; | |||||
| if (metadata.ContainsKey("dtype")) | |||||
| var metadata = revived_object.SerializedAttributes["metadata"] as KerasMetaData; | |||||
| if (metadata.DType != TF_DataType.DtInvalid) | |||||
| { | { | ||||
| // TODO(Rinne): set_dtype_policy. | // TODO(Rinne): set_dtype_policy. | ||||
| } | } | ||||
| revived_object.args.Trainable = metadata["trainable"].Value<bool>(); | |||||
| revived_object.args.Trainable = metadata.Trainable; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -305,6 +310,11 @@ namespace Tensorflow.Keras.Saving | |||||
| private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | ||||
| { | { | ||||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | ||||
| // Debug(Rinne) | |||||
| if(node_id == 11) | |||||
| { | |||||
| Console.WriteLine(); | |||||
| } | |||||
| if (loaded_nodes.ContainsKey(node_id)) | if (loaded_nodes.ContainsKey(node_id)) | ||||
| { | { | ||||
| @@ -472,15 +482,7 @@ namespace Tensorflow.Keras.Saving | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var properties = layer.GetType().GetProperties(); | |||||
| foreach(var p in properties) | |||||
| { | |||||
| if(p.Name == name as string && p.GetValue(layer) is not null) | |||||
| { | |||||
| return; | |||||
| } | |||||
| } | |||||
| Loader.setattr(layer, name, value); | |||||
| layer.SetAttr(name as string, value); | |||||
| } | } | ||||
| } | } | ||||
| @@ -607,7 +609,7 @@ namespace Tensorflow.Keras.Saving | |||||
| if(build_input_shape is null) | if(build_input_shape is null) | ||||
| { | { | ||||
| build_input_shape = _infer_inputs(node_id, convert_to_shapes: true); | |||||
| build_input_shape = _infer_input_shapes(node_id); | |||||
| } | } | ||||
| if(build_input_shape is not null) | if(build_input_shape is not null) | ||||
| @@ -633,7 +635,7 @@ namespace Tensorflow.Keras.Saving | |||||
| /// <param name="layer_node_id"></param> | /// <param name="layer_node_id"></param> | ||||
| /// <param name="convert_to_shapes"></param> | /// <param name="convert_to_shapes"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false) | |||||
| private TensorSpec _infer_inputs(int layer_node_id) | |||||
| { | { | ||||
| var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); | var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); | ||||
| if(call_fn_id is null) | if(call_fn_id is null) | ||||
| @@ -648,7 +650,22 @@ namespace Tensorflow.Keras.Saving | |||||
| } | } | ||||
| var call_fn_name = concrete_functions[0]; | var call_fn_name = concrete_functions[0]; | ||||
| var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; | var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; | ||||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
| var structured_input_signature = nested_structure_coder.decode_proto(call_fn_proto.CanonicalizedInputSignature); | |||||
| Debug.Assert(structured_input_signature is IEnumerable); | |||||
| var first_enumerator = (structured_input_signature as IEnumerable).GetEnumerator(); | |||||
| first_enumerator.MoveNext(); | |||||
| var first = first_enumerator.Current; | |||||
| Debug.Assert(first is IEnumerable); | |||||
| var inputs_enumerator = (first as IEnumerable).GetEnumerator(); | |||||
| inputs_enumerator.MoveNext(); | |||||
| var inputs = inputs_enumerator.Current as TensorSpec; | |||||
| return inputs; | |||||
| } | |||||
| private Shape _infer_input_shapes(int layer_node_id) | |||||
| { | |||||
| var inputs = _infer_inputs(layer_node_id); | |||||
| return nest.map_structure(x => x.shape, inputs); | |||||
| } | } | ||||
| private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child) | private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child) | ||||
| @@ -48,19 +48,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var properties = layer.GetType().GetProperties(); | |||||
| foreach (var p in properties) | |||||
| { | |||||
| if ((string)name == p.Name) | |||||
| { | |||||
| if(p.GetValue(layer) is not null) | |||||
| { | |||||
| return; | |||||
| } | |||||
| p.SetValue(layer, value); | |||||
| return; | |||||
| } | |||||
| } | |||||
| layer.SetAttr(name as string, value); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -11,7 +11,7 @@ using Tensorflow.Keras.Optimizers; | |||||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | using ThirdParty.Tensorflow.Python.Keras.Protobuf; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| using System.Diagnostics; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | namespace Tensorflow.Keras.Saving.SavedModel; | ||||
| @@ -135,12 +135,17 @@ public partial class KerasSavedModelUtils | |||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | if (x is ResourceVariable or RefVariable) return (Trackable)x; | ||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | ||||
| })); | })); | ||||
| var layers = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||||
| Dictionary<string, Trackable> res = new(); | Dictionary<string, Trackable> res = new(); | ||||
| res["variables"] = variables; | |||||
| res["trainable_variables"] = trainable_variables; | |||||
| res["non_trainable_variables"] = non_trainable_variables; | |||||
| res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||||
| Debug.Assert(variables is Trackable); | |||||
| Debug.Assert(trainable_variables is Trackable); | |||||
| Debug.Assert(non_trainable_variables is Trackable); | |||||
| Debug.Assert(layers is Trackable); | |||||
| res["variables"] = variables as Trackable; | |||||
| res["trainable_variables"] = trainable_variables as Trackable; | |||||
| res["non_trainable_variables"] = non_trainable_variables as Trackable; | |||||
| res["layers"] = layers as Trackable; | |||||
| return res; | return res; | ||||
| } | } | ||||
| @@ -165,6 +165,14 @@ namespace Tensorflow.Keras.Utils | |||||
| } | } | ||||
| } | } | ||||
| public static bool has_weights(object obj) | |||||
| { | |||||
| var obj_type = obj.GetType(); | |||||
| return obj_type.GetField("trainable_weights") is not null && | |||||
| obj_type.GetField("non_trainable_weights") is not null && | |||||
| obj is not Type; | |||||
| } | |||||
| // recusive | // recusive | ||||
| static bool uses_keras_history(Tensor op_input) | static bool uses_keras_history(Tensor op_input) | ||||
| { | { | ||||
| @@ -0,0 +1,22 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Framework.Models; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow.Keras.Utils | |||||
| { | |||||
| internal static class compile_utils | |||||
| { | |||||
| public static List<string> create_pseudo_input_names(TensorSpec inputs) | |||||
| { | |||||
| return _create_pseudo_names(inputs, "input_"); | |||||
| } | |||||
| private static List<string> _create_pseudo_names(TensorSpec tensors, string prefix) | |||||
| { | |||||
| // TODO(Rinne): align with tensorflow | |||||
| return new List<string>() { $"{prefix}1" }; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using Tensorflow.Framework.Models; | |||||
| namespace Tensorflow.Keras.Utils | namespace Tensorflow.Keras.Utils | ||||
| { | { | ||||
| @@ -69,5 +70,29 @@ namespace Tensorflow.Keras.Utils | |||||
| false_fn: false_fn, | false_fn: false_fn, | ||||
| name: name); | name: name); | ||||
| } | } | ||||
| public static TensorSpec get_tensor_spec(Tensor t, bool dynamic_batch = false, string name = null) | |||||
| { | |||||
| throw new NotImplementedException("The function is waited to be implemented in the future."); | |||||
| } | |||||
| public static TensorSpec get_tensor_spec(TensorSpec t, bool dynamic_batch = false, string name = null) | |||||
| { | |||||
| var spec = t; | |||||
| if (!dynamic_batch) | |||||
| { | |||||
| return spec; | |||||
| } | |||||
| var dynamic_batch_spec = new TensorSpec(t.shape, t.dtype, t.name); | |||||
| var shape = dynamic_batch_spec.shape; | |||||
| if(shape.rank > 0) | |||||
| { | |||||
| var shape_list = shape.as_int_list(); | |||||
| // TODO(Rinne): check if -1 is equivalent to None in python. | |||||
| shape_list[0] = -1; | |||||
| dynamic_batch_spec.shape = new Shape(shape_list); | |||||
| } | |||||
| return dynamic_batch_spec; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -64,5 +64,8 @@ public class SequentialModelLoad | |||||
| { | { | ||||
| var model = tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func"); | var model = tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func"); | ||||
| model.summary(); | model.summary(); | ||||
| var x = tf.ones((2, 10)); | |||||
| var y = model.Apply(x); | |||||
| } | } | ||||
| } | } | ||||