* Add check for dims of x and y in model.fit.
* Init the serialization of the Keras pb model.
* Add more facilities to the saved model framework.
* Add ListWrapper and ITrackable, and revise implementations.
* Add serialized attributes.
* Implement layer serializations.
* Add missing implementations (mainly MultiDeviceSaver).
* Support autograph.to_graph under graph mode.
* Add more implementations to the pb model save.
* Add more implementations to the Keras part of the pb model save.
* Refine some code after merge.
* Add two simple sequential test cases of pb model save.
* Implement serializing attributes other than Keras arg definitions.
* Add AlexNet pb save test.
* Check and refine the code.

---------

Co-authored-by: AsakusaRinne <AsakusaRinne@gmail.com>
@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using System.Text;

namespace Tensorflow
{
    public partial class tensorflow
@@ -23,6 +25,26 @@ namespace Tensorflow
        public class CompatApi
        {
            public CompatV1Api v1 { get; } = new CompatV1Api();

            internal string as_text(string bytes_or_text, Encoding? encoding = null)
            {
                // Already a string; `encoding` is accepted only for parity with the byte[] overload.
                return bytes_or_text;
            }
            internal string as_text(byte[] bytes_or_text, Encoding? encoding = null)
            {
                if (encoding is null) encoding = Encoding.UTF8;
                return encoding.GetString(bytes_or_text);
            }
            internal string as_str(string bytes_or_text, Encoding? encoding = null)
            {
                return as_text(bytes_or_text, encoding);
            }
            internal string as_str(byte[] bytes_or_text, Encoding? encoding = null)
            {
                return as_text(bytes_or_text, encoding);
            }
        }

        public bool executing_eagerly()
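A minimal sketch of the semantics of the helpers above (illustrative only; `as_text`/`as_str` are internal, so this assumes a caller inside the Tensorflow assembly):

using System.Text;
using Tensorflow;

var compat = new tensorflow.CompatApi();
string a = compat.as_text(Encoding.UTF8.GetBytes("hello")); // "hello": bytes are decoded, UTF-8 by default
string b = compat.as_text("hello");                         // "hello": strings pass through unchanged
string c = compat.as_str("hello");                          // as_str forwards to as_text for both overloads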
@@ -0,0 +1,152 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint;

public static class CheckPointUtils
{
    private static string _ESCAPE_CHAR = ".";

    public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>,
        IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
        IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
    {
        var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
        Dictionary<Trackable, string> object_names = new();
        foreach (var pair in node_paths)
        {
            object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
        }
        Dictionary<Trackable, int> node_ids = new();
        for (int i = 0; i < trackable_objects.Count; i++)
        {
            node_ids[trackable_objects[i]] = i;
        }
        var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names);
        return (trackable_objects, node_paths, node_ids, slot_variables, object_names);
    }
    public static
        IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
        serialize_slot_variables(IEnumerable<Trackable> trackable_objects,
            IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names)
    {
        var non_slot_objects = trackable_objects.ToList();
        Dictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
            slot_variables = new();
        foreach (var trackable in non_slot_objects)
        {
            if (trackable is not Optimizer optim)
            {
                continue;
            }
            var slot_names = optim.get_slot_names();
            foreach (var slot_name in slot_names)
            {
                for (int original_variable_node_id = 0;
                     original_variable_node_id < non_slot_objects.Count;
                     original_variable_node_id++)
                {
                    var original_variable = non_slot_objects[original_variable_node_id];
                    // Only variables can have slot variables attached to them.
                    if (original_variable is not IVariableV1 variable)
                    {
                        continue;
                    }
                    var slot_variable = optim.get_slot(variable, slot_name);
                    if (slot_variable is null) continue;
                    // There are some problems with the inheritance of `Variable` and `Trackable`.
                    throw new NotImplementedException();
                }
            }
        }
        return slot_variables;
    }
    public static Trackable get_mapped_trackable(Trackable trackable, IDictionary<Trackable, Trackable>? object_map)
    {
        if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res))
        {
            return trackable;
        }
        else
        {
            return possible_res;
        }
    }

    public static string get_full_name(Trackable variable)
    {
        // TODO: This state is not correct; the whole framework needs to be updated in the future.
        if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable)))
        {
            return "";
        }
        // Skip the check of the `_save_slice_info` attribute.
        // TODO: Needs to be revised!
        Debug.Assert(variable is BaseResourceVariable);
        return ((BaseResourceVariable)variable).Name;
    }
    public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto)
    {
        HashSet<int> checkpointed_trackables = new();
        Dictionary<int, HashSet<int>> parents = new();
        for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
        {
            var object_proto = object_graph_proto.Nodes[i];
            // skip the process of registered saver.
            if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 ||
                object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0)
            {
                checkpointed_trackables.Add(i);
            }
            foreach (var child_proto in object_proto.Children)
            {
                var child = child_proto.NodeId;
                if (!parents.ContainsKey(child))
                {
                    parents[child] = new HashSet<int>();
                }
                parents[child].Add(i);
            }
        }

        Queue<int> to_visit = new(checkpointed_trackables.AsEnumerable());
        while (to_visit.Count > 0)
        {
            var trackable = to_visit.Dequeue();
            if (!parents.ContainsKey(trackable)) continue;
            var current_parents = parents[trackable];
            foreach (var parent in current_parents)
            {
                checkpointed_trackables.Add(parent);
                if (parents.ContainsKey(parent))
                {
                    to_visit.Enqueue(parent);
                }
            }
            parents.Remove(trackable);
        }
        // TODO: Complete it after supporting checkpoint.
        // for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
        // {
        //     object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
        // }
    }
}
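A small sketch of what `add_checkpoint_values_check` computes, using a hypothetical three-node graph (0 -> 1 -> 2) where only the leaf node carries serialized attributes:

using Tensorflow;
using Tensorflow.Checkpoint;

var proto = new TrackableObjectGraph();
for (int i = 0; i < 3; i++)
    proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject());
proto.Nodes[0].Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference { NodeId = 1, LocalName = "layer" });
proto.Nodes[1].Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference { NodeId = 2, LocalName = "kernel" });
proto.Nodes[2].Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor { Name = "VARIABLE_VALUE" });

// Node 2 is checkpointed directly; the queue walk over `parents` also marks its ancestors 1 and 0.
// The result will be written into the proto once the commented-out flag assignment is enabled.
CheckPointUtils.add_checkpoint_values_check(proto);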
@@ -0,0 +1,5 @@
namespace Tensorflow.Checkpoint;

public record class CheckpointOptions(
    string? experimental_io_device = null,
    bool experimental_enable_async_checkpoint = false);
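A usage sketch (the device string is hypothetical); since `CheckpointOptions` is a positional record, it also supports `with` expressions:

using Tensorflow.Checkpoint;

var options = new CheckpointOptions(experimental_io_device: "/job:localhost");
// Records allow non-destructive updates:
var asyncOptions = options with { experimental_enable_async_checkpoint = true };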
@@ -0,0 +1,64 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Serilog.Debugging;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;

namespace Tensorflow.Checkpoint;

public class ObjectGraphView : TrackableView, ICloneable
{
    protected IEnumerable<TrackableReference>? _attached_dependencies;

    // TODO: attached_dependencies
    public ObjectGraphView(Trackable root, IEnumerable<TrackableReference>? attached_dependencies = null) : base(root)
    {
        _attached_dependencies = attached_dependencies;
    }

    public object Clone()
    {
        // TODO: Implement real deep copy corresponding to
        // tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__.
        return new ObjectGraphView(Root, _attached_dependencies);
    }

    public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
    {
        List<TrackableReference> res = base.children(obj, save_type, serialization_cache)
            .Select(x => new TrackableReference(x.Key, x.Value)).ToList();
        // Check the reference, not the value.
        if (obj == Root && _attached_dependencies is not null)
        {
            res.AddRange(_attached_dependencies);
        }
        return res;
    }

    public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
    {
        return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer);
    }

    public IEnumerable<TrackableReference>? AttachedDependencies
    {
        get => _attached_dependencies;
    }

    public virtual (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
    {
        return base._descendants_with_paths();
    }

    // TODO: complete the implementation
    public void serialize_object_graph(object? saveables_cache = null)
    {
        throw new NotImplementedException();
    }

    // TODO: complete the implementation
    public void frozen_saveable_objects(object? object_map = null, object? to_graph = null, object call_with_mapped_captures = null)
    {
        throw new NotImplementedException();
    }
}
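A sketch of how the traversal is typically consumed (assumes `root` is any `Trackable`, e.g. a built Keras model):

using System;
using Tensorflow.Checkpoint;
using Tensorflow.Train;
using Tensorflow.Training;

static void PrintObjectPaths(Trackable root)
{
    var graph_view = new ObjectGraphView(root);
    var (trackables, node_paths) = graph_view.breadth_first_traversal();
    for (int node_id = 0; node_id < trackables.Count; node_id++)
    {
        // node_id matches the index later used in TrackableObjectGraph.nodes.
        var name = TrackableUtils.object_path_to_string(node_paths[trackables[node_id]]);
        Console.WriteLine($"{node_id}: {name}");
    }
}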
@@ -0,0 +1,255 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint
{
    internal record class TrackableData(
        // A trackable in the root Trackable object graph.
        Trackable trackable,
        // The index at which the Trackable appears in TrackableObjectGraph.nodes.
        int node_id,
        // The BFS-generated path from the root object, used to generate readable checkpoint keys.
        string object_name,
        // A list of ObjectReference for each child connected to this Trackable.
        pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto,
        // A list of SlotVariableReference to save to the object (only valid for Optimizer objects).
        pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot_variable_proto,
        // The object to save to checkpoint. Usually this is the same as `trackable`,
        // but can differ when the caller wants to specify a different object to
        // save. For example, when saving checkpoints asynchronously, variables are
        // copied to the CPU. `object_to_save` is set as the copied variable.
        Trackable object_to_save
    );

    public static class SaveUtil
    {
        public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
            serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
        {
            var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
            var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data);

            var object_graph_proto = fill_object_graph_proto(trackable_data);

            var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
            var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto);

            Dictionary<Tensor, object> feed_additions;
            if (cache is null)
            {
                feed_additions = null;
                serialized_tensors = serialized_tensors.Concat(get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures,
                    cache, object_graph_proto)).ToDictionary(x => x.Key, x => x.Value);
            }
            else
            {
                feed_additions = null;
                // TODO: deal with cache.
                throw new NotImplementedException();
            }

            CheckPointUtils.add_checkpoint_values_check(object_graph_proto);
            return (serialized_tensors, feed_additions, registered_savers, object_graph_proto);
        }
        private static (IList<TrackableData>, IDictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map)
        {
            var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
            Dictionary<Trackable, string> object_names = new();
            foreach (var pair in node_paths)
            {
                object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
            }
            Dictionary<Trackable, int> node_ids = new();
            for (int i = 0; i < trackable_objects.Count; i++)
            {
                node_ids[trackable_objects[i]] = i;
            }
            var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names);
            List<TrackableData> trackable_data = new();
            foreach (var trackable in trackable_objects)
            {
                pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto = new();
                foreach (var child in graph_view.list_children(trackable))
                {
                    children_proto.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference()
                    {
                        NodeId = node_ids[child.Refer],
                        LocalName = child.Name
                    });
                }
                slot_variables.TryGetValue(trackable, out var slot_variable);
                trackable_data.Add(new TrackableData(
                    trackable: trackable,
                    node_id: node_ids[trackable],
                    object_name: object_names[trackable],
                    children_proto: children_proto,
                    slot_variable_proto: slot_variable ?? new pbc.RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>(),
                    object_to_save: CheckPointUtils.get_mapped_trackable(trackable, object_map)
                ));
            }
            return (trackable_data, node_ids);
        }

        private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData> trackable_data)
        {
            TrackableObjectGraph object_graph_proto = new();
            for (int i = 0; i < trackable_data.Count; i++)
            {
                var td = trackable_data[i];
                Debug.Assert(td.node_id == i);
                object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto));
            }
            return object_graph_proto;
        }
        /// <summary>
        /// Creates a dictionary of tensors to checkpoint, and updates the proto.
        /// </summary>
        /// <param name="tensor_trackables"></param>
        /// <param name="node_ids"></param>
        /// <param name="call_with_mapped_captures"></param>
        /// <param name="cache"></param>
        /// <param name="object_graph_proto"></param>
        private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
            bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
        {
            Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
            foreach (var td in tensor_trackables)
            {
                // TODO: deal with cache.
                var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
                Trackable trackable = null;
                IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict;
                if (!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
                {
                    (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto);
                }
                else
                {
                    tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto);
                    trackable = td.object_to_save;
                }
                if (trackable is not null)
                {
                    serialized_tensors[trackable] = tensor_dict;
                }
                else
                {
                    serialized_tensors[Trackable.None] = tensor_dict;
                }
            }
            return serialized_tensors;
        }
        private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
        {
            var trackable = trackable_data.object_to_save;

            // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
            IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
            if (call_with_mapped_captures)
            {
                throw new NotImplementedException();
            }
            else
            {
                ret_tensor_dict = trackable.serialize_to_tensors();
            }

            // TODO: deal with the type `SaveSpec` (currently it never occurs here).
            Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
            foreach (var pair in ret_tensor_dict)
            {
                var local_name = TrackableUtils.escape_local_name(pair.Key);
                var maybe_tensor = pair.Value;
                var checkpoint_key = TrackableUtils.checkpoint_key(trackable_data.object_name, local_name);

                tensor_dict[checkpoint_key] = maybe_tensor;

                if (maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>())
                {
                    throw new NotImplementedException();
                    //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
                }

                if (object_graph_proto is not null)
                {
                    object_graph_proto.Nodes[trackable_data.node_id].Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
                    {
                        Name = local_name,
                        CheckpointKey = checkpoint_key,
                        FullName = CheckPointUtils.get_full_name(trackable)
                    });
                }
            }
            return tensor_dict;
        }
        /// <summary>
        /// Gets tensors to serialize from a Trackable with legacy SaveableObjects.
        /// </summary>
        /// <param name="trackable_data"></param>
        /// <param name="node_ids"></param>
        /// <param name="call_with_mapped_captures"></param>
        /// <param name="object_graph_proto"></param>
        /// <returns></returns>
        private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
            bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
        {
            Dictionary<Trackable, string> object_names = new();
            object_names[trackable_data.trackable] = trackable_data.object_name;
            Dictionary<Trackable, Trackable> object_map = new();
            object_map[trackable_data.trackable] = trackable_data.object_to_save;

            var (checkpoint_factory_map, _) = SaveUtilV1.get_checkpoint_factories_and_keys(object_names, object_map);
            var (named_saveable_objects, _) = SaveUtilV1.generate_saveable_objects(checkpoint_factory_map, object_graph_proto, node_ids, object_map,
                call_with_mapped_captures, saveables_cache: null);
            var trackable = new SaveableCompatibilityConverter(trackable_data.object_to_save, named_saveable_objects);
            return (trackable, trackable.serialize_to_tensors());
        }
        private static IDictionary<string, IDictionary<string, Trackable>> get_and_write_registered_savers(IDictionary<string, IList<TrackableData>> registered_trackables, TrackableObjectGraph object_graph_proto)
        {
            Dictionary<string, IDictionary<string, Trackable>> registered_savers = new();
            foreach (var pair in registered_trackables)
            {
                foreach (var td in pair.Value)
                {
                    // Ensure the per-saver dictionary exists, then record the object under its name.
                    if (!registered_savers.ContainsKey(pair.Key))
                    {
                        registered_savers[pair.Key] = new Dictionary<string, Trackable>();
                    }
                    registered_savers[pair.Key][td.object_name] = td.object_to_save;

                    var object_proto = object_graph_proto.Nodes[td.node_id];
                    // TODO: add APIs and complete it. Now `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`.
                }
            }
            return registered_savers;
        }
        private static (IList<TrackableData>, IList<TrackableData>, IDictionary<string, IList<TrackableData>>) split_trackables(IEnumerable<TrackableData> trackable_data)
        {
            List<TrackableData> tensor_trackables = new();
            List<TrackableData> py_state_trackables = new(); // Skip the process of `PyState` for the lack of API. This is only a placeholder.
            Dictionary<string, IList<TrackableData>> registered_trackables = new();

            foreach (var td in trackable_data)
            {
                // TODO: deal with registration.
                tensor_trackables.Add(td);
            }
            return (tensor_trackables, py_state_trackables, registered_trackables);
        }
    }
}
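A sketch of the top-level entry point defined above (assumes `root` is a `Trackable`):

using Tensorflow;
using Tensorflow.Checkpoint;
using Tensorflow.Train;

static TrackableObjectGraph SerializeForCheckpoint(Trackable root)
{
    var graph_view = new ObjectGraphView(root);
    var (serialized_tensors, feed_additions, registered_savers, graph_proto) =
        SaveUtil.serialize_graph_view(graph_view);
    // serialized_tensors maps each Trackable to (checkpoint_key -> Tensor or slice dict),
    // which is exactly the shape MultiDeviceSaver expects; graph_proto holds the metadata.
    return graph_proto;
}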
@@ -0,0 +1,223 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Exceptions;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;
using Google.Protobuf;

namespace Tensorflow.Checkpoint;

public static class SaveUtilV1
{
    public static (IDictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names,
        IDictionary<Trackable, Trackable>? object_map = null)
    {
        // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md,
        // only internal registrations are allowed so far, so this function does not return a saver.
        // The implementation should be updated if tensorflow updates it.
        Dictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map = new();
        foreach (var pair in object_names)
        {
            var trackable = pair.Key;
            var object_name = pair.Value;
            var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map);
            // Skip the registration process.
            List<CheckpointFactoryData> current_list = new();
            foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save))
            {
                // Treat `name` as the key suffix.
                var name = name_and_factory.Key;
                var checkpoint_key = TrackableUtils.checkpoint_key(object_name, name);
                current_list.Add(new CheckpointFactoryData(name_and_factory.Value, name, checkpoint_key));
            }
            checkpoint_factory_map[trackable] = current_list;
        }
        return (checkpoint_factory_map, null);
    }
    public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view,
        IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures,
        object? saveables_cache = null)
    {
        if (to_graph is not null)
        {
            var g = to_graph.as_default();
            var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
                object_map, call_with_mapped_captures, saveables_cache);
            tf.device("/cpu:0");
            var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
            named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
            g.Exit();
            return (named_saveable_objects, registered_savers);
        }
        else
        {
            using (new ops.NullContextManager())
            {
                var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
                    object_map, call_with_mapped_captures, saveables_cache);
                tf.device("/cpu:0");
                // Serialize the proto to bytes, consistently with the graph branch above.
                var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
                named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
                return (named_saveable_objects, registered_savers);
            }
        }
    }
    public static (IList<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view,
        IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
    {
        var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
        Dictionary<Trackable, string> object_names = new();
        foreach (var pair in node_paths)
        {
            object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
        }
        Dictionary<Trackable, int> node_ids = new();
        for (int i = 0; i < trackable_objects.Count; i++)
        {
            node_ids[trackable_objects[i]] = i;
        }
        var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names);
        var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables);
        var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph(
            trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures,
            saveables_cache);
        CheckPointUtils.add_checkpoint_values_check(object_graph_proto);
        return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers);
    }

    private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList<Trackable> trackable_objects,
        IDictionary<Trackable, int> node_ids,
        IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
            slot_variables)
    {
        TrackableObjectGraph object_graph_proto = new();
        for (int i = 0; i < trackable_objects.Count; i++)
        {
            var trackable = trackable_objects[i];
            Debug.Assert(node_ids[trackable] == i);
            TrackableObjectGraph.Types.TrackableObject object_proto;
            if (slot_variables.TryGetValue(trackable, out var slots))
            {
                object_proto = new TrackableObjectGraph.Types.TrackableObject(slots);
            }
            else
            {
                object_proto = new TrackableObjectGraph.Types.TrackableObject();
            }
            object_graph_proto.Nodes.Add(object_proto);
            foreach (var child in graph_view.list_children(trackable))
            {
                object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference()
                    { NodeId = node_ids[child.Refer], LocalName = child.Name });
            }
        }
        return object_graph_proto;
    }

    private static (IList<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph(
        IList<Trackable> trackable_objects,
        TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids,
        IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map,
        bool call_with_mapped_captures, object? saveables_cache = null)
    {
        int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count);
        for (int i = 0; i < cnt; i++)
        {
            Debug.Assert(node_ids[trackable_objects[i]] == i);
        }
        var (checkpoint_factory_map, unmapped_registered_savers) =
            get_checkpoint_factories_and_keys(object_names, object_map);
        // Skip the process of registered savers.
        var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map,
            object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache);
        return (named_saveable_objects, feed_additions, null);
    }
    public static (IList<MySaveableObject>, object?) generate_saveable_objects(
        IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map,
        TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids,
        IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
    {
        List<MySaveableObject> named_saveable_objects = new();
        foreach (var pair in checkpoint_factory_map)
        {
            var trackable = pair.Key;
            var factory_data_list = pair.Value;
            bool fill_object_proto = object_graph_proto is not null && node_ids is not null;
            TrackableObjectGraph.Types.TrackableObject object_proto = null!;
            if (fill_object_proto)
            {
                object_proto = object_graph_proto.Nodes[node_ids[trackable]];
            }

            var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map);
            // Skip the cache.

            foreach (var factory_data in factory_data_list)
            {
                var name = factory_data.name;
                var key = factory_data.checkpoint_key;
                var maybe_saveable = factory_data.factory;

                // TODO: tensorflow python has a process with a callable `saveable_factory` here.
                List<MySaveableObject> saveables = new();
                if (maybe_saveable.TryGet<MySaveableObject>(out var s))
                {
                    saveables.Add(s);
                }
                else
                {
                    saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key));
                }

                foreach (var saveable in saveables)
                {
                    if (!saveable.name.Contains(key))
                    {
                        throw new AssertionError($"The object {trackable} produced a SaveableObject with name " +
                            $"'{saveable.name}' for attribute '{name}'. Expected a name" +
                            $" containing '{key}'.");
                    }
                }

                // Skip the process of PythonState.

                named_saveable_objects.AddRange(saveables);

                if (!fill_object_proto) continue;

                // Skip the process of `TrackableSaveable` because of the lack of APIs.

                object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
                    { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) });
            }
        }
        return (named_saveable_objects, null);
    }
}

public record class CheckpointFactoryData
(
    Maybe<BaseResourceVariable, MySaveableObject> factory,
    string name,
    string checkpoint_key
);
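A sketch of the V1, `SaveableObject`-based path (assumes `root` is a `Trackable` and `graph` an existing `Graph`):

using System.Collections.Generic;
using Tensorflow;
using Tensorflow.Checkpoint;
using Tensorflow.Train;

static void FreezeToGraph(Trackable root, Graph graph)
{
    var graph_view = new ObjectGraphView(root);
    var (saveables, registered_savers) = SaveUtilV1.frozen_saveables_and_savers(
        graph_view, new Dictionary<Trackable, Trackable>(), graph, call_with_mapped_captures: false);
    // Besides the variable saveables, the list now contains a NoRestoreSaveable that
    // carries the serialized object graph under Trackable.Constants.OBJECT_GRAPH_PROTO_KEY.
}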
@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Checkpoint
{
    internal static class SaveableCompat
    {
        public static string? get_saveable_name(Trackable cls_or_obj)
        {
            // TODO: implement it with Attribute.
            return null;
        }
    }
}
@@ -0,0 +1,82 @@
using System;
using Tensorflow.Train;
using System.Collections.Generic;
using System.IO;
using Tensorflow.Keras.Saving.SavedModel;

namespace Tensorflow.Checkpoint;

public class TrackableView
{
    protected WeakReference<Trackable> _root_ref;

    public TrackableView(Trackable obj)
    {
        _root_ref = new WeakReference<Trackable>(obj);
    }

    public TrackableView(WeakReference<Trackable> obj)
    {
        _root_ref = obj;
    }

    public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
    {
        obj._maybe_initialize_trackable();
        Dictionary<string, Trackable> children = new();
        // Note: in python the return type of `Trackable._trackable_children` is not fixed,
        // so an extra conversion step (`convert_to_trackable`) is applied there.
        foreach (var pair in obj._trackable_children(save_type, cache))
        {
            children[pair.Key] = pair.Value;
        }
        return children;
    }

    public Trackable Root
    {
        get
        {
            if (_root_ref.TryGetTarget(out Trackable res))
            {
                return res;
            }
            else
            {
                throw new InvalidDataException(
                    "Cannot get the object from the weak reference. Please consider whether a null reference was passed to the constructor.");
            }
        }
    }

    /// <summary>
    /// Returns a list of all nodes and their paths from self.root using a breadth-first traversal.
    /// Corresponds to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths.
    /// </summary>
    protected (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths()
    {
        List<Trackable> bfs_sorted = new();
        Queue<Trackable> to_visit = new();
        to_visit.Enqueue(Root);
        Dictionary<Trackable, IEnumerable<TrackableReference>> node_paths = new();
        node_paths[this.Root] = new List<TrackableReference>();
        while (!to_visit.empty())
        {
            var current_trackable = to_visit.Dequeue();
            bfs_sorted.Add(current_trackable);
            var children_dict = this.children(current_trackable);
            foreach (var name in children_dict.Keys)
            {
                var dependency = children_dict[name];
                if (!node_paths.ContainsKey(dependency))
                {
                    // Extend the parent's path with this edge and schedule the child for a visit.
                    var list = new List<TrackableReference>(node_paths[current_trackable]);
                    list.Add(new TrackableReference(name, dependency));
                    node_paths[dependency] = list;
                    to_visit.Enqueue(dependency);
                }
            }
        }
        return (bfs_sorted, node_paths);
    }
}
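A sketch of listing direct children through the view (assumes `obj` is any `Trackable`):

using System;
using Tensorflow.Checkpoint;
using Tensorflow.Train;

static void PrintChildren(Trackable obj)
{
    var view = new TrackableView(obj);
    foreach (var pair in view.children(obj))
    {
        // pair.Key is the local attribute name; pair.Value is the child Trackable.
        Console.WriteLine(pair.Key);
    }
}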
@@ -0,0 +1,195 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Train;
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Checkpoint;

/// <summary>
/// Saves and restores a `Trackable` object and its dependencies.
/// </summary>
public class TrackableSaver
{
    private ObjectGraphView _graph_view;
    private Tensor _cached_save_operation;
    private TrackableObjectGraph _last_save_object_graph;
    private Tensor? _object_graph_feed_tensor = null;
    private Tensor? _file_prefix_feed_tensor = null;
    private Dictionary<Trackable, Trackable>? _object_map = null;
    private object? _cache = null;

    public TrackableSaver(ObjectGraphView graph_view)
    {
        _graph_view = graph_view;

        // TODO: cache when not executing eagerly,
        // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`,
        // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache`.
    }

    private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
        gather_serialized_tensors(Tensor? object_graph_tensor = null)
    {
        var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache: _cache);

        // TODO: cache.

        if (object_graph_tensor is null)
        {
            tf.device("/cpu:0");
            object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
        }
        else
        {
            feed_additions[object_graph_tensor] = graph_proto.ToByteArray();
        }
        Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
        if (!serialized_tensors.ContainsKey(Trackable.None))
        {
            serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>();
        }
        serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
        return (serialized_tensors, feed_additions, registered_savers, graph_proto);
    }
    private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
    {
        var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor);

        Func<(Tensor, IDictionary<Tensor, object>)> run_save = () =>
        {
            if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function())
            {
                var saver = new MultiDeviceSaver(serialized_tensors, registered_savers);
                var save_op = saver.save(file_prefix, options);

                // tensorflow python: `with ops.device("/cpu:0"):`
                using (ops.control_dependencies(new object[] { save_op }))
                {
                    _cached_save_operation = array_ops.identity(file_prefix);
                }
                _last_save_object_graph = graph_proto;
            }
            return (_cached_save_operation, feed_additions);
        };

        if (options.experimental_enable_async_checkpoint)
        {
            throw new NotImplementedException();
        }

        return run_save();
    }

    private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
    {
        var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor);

        Func<(Tensor, IDictionary<Tensor, object>)> run_save = () =>
        {
            if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function())
            {
                var saver = new MultiDeviceSaver(serialized_tensors, registered_savers);
                var save_op = saver.save(file_prefix, options);

                // tensorflow python: `with ops.device("/cpu:0"):`
                using (ops.control_dependencies(new object[] { save_op }))
                {
                    _cached_save_operation = array_ops.identity(tf.constant(file_prefix));
                }
                _last_save_object_graph = graph_proto;
            }
            return (_cached_save_operation, feed_additions);
        };

        if (options.experimental_enable_async_checkpoint)
        {
            throw new NotImplementedException();
        }

        return run_save();
    }
    // TODO: parameter write_done_callback
    public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null,
        CheckpointOptions? options = null)
    {
        if (options is null)
        {
            options = new CheckpointOptions();
        }

        Dictionary<Tensor, object> feed_dict = new();
        bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function());
        if (checkpoint_number is not null)
        {
            file_prefix = $"{file_prefix}-{checkpoint_number}";
        }

        Tensor file_prefix_tensor;
        Tensor object_graph_tensor;
        if (use_session)
        {
            if (_object_graph_feed_tensor is null)
            {
                // In python there is `with ops.device("/cpu:0")`.
                _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
                _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
            }
            object_graph_tensor = _object_graph_feed_tensor;
            file_prefix_tensor = _file_prefix_feed_tensor;
            feed_dict[file_prefix_tensor] = file_prefix;
        }
        else
        {
            // In python there is `with ops.device("/cpu:0")`.
            file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING);
            object_graph_tensor = null;
        }

        // Pass the tensor so that, in session mode, the fed file prefix is actually used.
        var (save_path, new_feed_additions) =
            save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options);
        if (new_feed_additions is not null)
        {
            foreach (var pair in new_feed_additions)
            {
                feed_dict.Add(pair.Key, pair.Value);
            }
        }
        if (!use_session)
        {
            session = null;
        }
        else if (session is null)
        {
            session = new Session(); // In python it uses `get_session`.
        }

        if (session is not null)
        {
            var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray();
            return session.run((Tensor)save_path, s);
        }
        else if (use_session)
        {
            throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " +
                "in graph mode without a default session. Please use " +
                "`with tf.Session():` to create a session.");
        }
        else
        {
            return save_path;
        }
    }
}
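A usage sketch of the saver in eager mode (assumes `root` is a `Trackable`; the prefix path is hypothetical):

using Tensorflow;
using Tensorflow.Checkpoint;
using Tensorflow.Train;

static Tensor SaveCheckpoint(Trackable root)
{
    var saver = new TrackableSaver(new ObjectGraphView(root));
    // Writes checkpoint shards under the given prefix, e.g. "ckpt/model-1.index"
    // and "ckpt/model-1.data-00000-of-00001", and returns the prefix tensor.
    return saver.save("ckpt/model", checkpoint_number: 1);
}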
@@ -0,0 +1,540 @@
using System;
using System.Buffers.Text;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;
using static Tensorflow.ApiDef.Types;
using static Tensorflow.CostGraphDef.Types;
using static Tensorflow.OptimizerOptions.Types;
using static Tensorflow.Binding;
using System.Text.RegularExpressions;
using System.Linq;
using Tensorflow.Operations;
using Tensorflow.Training;
using Tensorflow.Graphs;
using System.Xml.Linq;
using System.Diagnostics;
using RestoreFunc = System.Func<object, object>;

namespace Tensorflow.Checkpoint
{
    public class Maybe<TA, TB>
    {
        private TA? _valueA = default(TA);
        private TB? _valueB = default(TB);
        private Type _type;
        private bool _assignedTA;

        public Maybe(TA value)
        {
            _valueA = value;
            _type = typeof(TA);
            _assignedTA = true;
        }
        public Maybe(TB value)
        {
            _valueB = value;
            _type = typeof(TB);
            _assignedTA = false;
        }

        public Type DataType => _type;

        /// <summary>
        /// Try to get the member of type T of this instance. It returns true when TA or TB derives from T
        /// and the corresponding side has been assigned; otherwise it returns false.
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="res"></param>
        /// <returns></returns>
        public bool TryGet<T>(out T? res)
        {
            if (_valueA is T && _valueB is not T)
            {
                res = (T)(object)_valueA;
                return _assignedTA;
            }
            else if (_valueA is not T && _valueB is T)
            {
                res = (T)(object)_valueB;
                return !_assignedTA;
            }
            res = default(T);
            return false;
        }

        public bool IsTypeOrDeriveFrom<T>()
        {
            if (_valueA is T && _valueB is not T)
            {
                return _assignedTA;
            }
            else if (_valueA is not T && _valueB is T)
            {
                return !_assignedTA;
            }
            else if (_valueA is T && _valueB is T)
            {
                return true;
            }
            else
            {
                return false;
            }
        }

        public T GetValue<T>()
        {
            if (_valueA is T && _valueB is not T)
            {
                return (T)(object)_valueA;
            }
            else if (_valueA is not T && _valueB is T)
            {
                return (T)(object)_valueB;
            }
            else if (_valueA is T && _valueB is T)
            {
                throw new TypeError("The type is ambiguous; this happens when TA and TB both derive from T.");
            }
            else
            {
                throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got {typeof(T)}.");
            }
        }

        public static implicit operator Maybe<TA, TB>(TA a)
        {
            return new Maybe<TA, TB>(a);
        }
        public static implicit operator Maybe<TA, TB>(TB b)
        {
            return new Maybe<TA, TB>(b);
        }
    }
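    // A usage sketch for `Maybe<TA, TB>` (illustrative only; not part of the diff):
    internal static class MaybeUsageExample
    {
        internal static void Run()
        {
            // The implicit conversion assigns the Tensor (TA) side.
            Maybe<Tensor, SaveSpec> maybe = tf.constant(1.0f);
            if (maybe.TryGet<Tensor>(out var tensor))
            {
                // Succeeds: the Tensor side was assigned. With a SaveSpec inside, this returns false.
            }
            var value = maybe.GetValue<Tensor>(); // throws TypeError if neither side matches the requested type
        }
    }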
    internal class SingleDeviceSaver
    {
        private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict;

        public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict)
        {
            _tensor_slice_dict = tensor_slice_dict;
        }
        public SingleDeviceSaver(IDictionary<string, IDictionary<string, Tensor>> tensor_slice_dict)
        {
            _tensor_slice_dict = tensor_slice_dict.ToDictionary(
                x => x.Key, x => x.Value.ToDictionary(
                    y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
                as IDictionary<string, Maybe<Tensor, SaveSpec>>);
        }
        public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict)
        {
            _tensor_slice_dict = tensor_slice_dict.ToDictionary(
                x => x.Key, x => x.Value.ToDictionary(
                    y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
                as IDictionary<string, Maybe<Tensor, SaveSpec>>);
        }

        public Operation? save(Tensor file_prefix, CheckpointOptions? options = null)
        {
            if (options is null)
            {
                options = new CheckpointOptions();
            }
            List<string> tensor_names = new();
            List<Tensor> tensors = new();
            List<string> slice_specs = new();
            foreach (var pair in _tensor_slice_dict)
            {
                var checkpoint_key = pair.Key;
                var tensor_slices = pair.Value;
                foreach (var slice in tensor_slices)
                {
                    var slice_spec = slice.Key;
                    var maybe_tensor = slice.Value;
                    if (maybe_tensor.TryGet<SaveSpec>(out var spec))
                    {
                        var tensor_value = spec.tensor;
                        if (tensor_value is not null)
                        {
                            tensor_names.Add(spec.name);
                            tensors.Add(tensor_value);
                            slice_specs.Add(spec.slice_spec);
                        }
                    }
                    else
                    {
                        var tensor = maybe_tensor.GetValue<Tensor>();
                        tensor_names.Add(checkpoint_key);
                        tensors.Add(tensor);
                        slice_specs.Add(slice_spec);
                    }
                }
            }
            // TODO: specify the device.
            return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray());
        }

        public Operation? save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options);

        public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_prefix, CheckpointOptions? options = null)
        {
            if (options is null)
            {
                options = new CheckpointOptions();
            }
            List<string> tensor_names = new();
            List<TF_DataType> tensor_dtypes = new();
            List<string> slice_specs = new();
            foreach (var pair in _tensor_slice_dict)
            {
                var checkpoint_key = pair.Key;
                var tensor_slices = pair.Value;
                foreach (var slice in tensor_slices)
                {
                    var slice_spec = slice.Key;
                    var maybe_tensor = slice.Value;
                    // TODO: deal with other types. Currently only `SaveSpec` is allowed.
                    if (maybe_tensor.TryGet<SaveSpec>(out var spec))
                    {
                        tensor_dtypes.Add(spec.dtype);
                        slice_specs.Add(spec.slice_spec);
                        tensor_names.Add(spec.name);
                    }
                    else
                    {
                        var tensor = maybe_tensor.GetValue<Tensor>();
                        tensor_dtypes.Add(tensor.dtype);
                        slice_specs.Add(slice_spec);
                        tensor_names.Add(checkpoint_key);
                    }
                }
            }
            string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0" : options.experimental_io_device!;

            // tf python has the code `with ops.device(restore_device):` here.
            tf.device(restore_device); // may be risky.
            var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());

            Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
            int idx = 0;
            foreach (var pair in _tensor_slice_dict)
            {
                var checkpoint_key = pair.Key;
                var tensor_slices = pair.Value;
                foreach (var slice_spec in tensor_slices.Keys)
                {
                    var restored_tensor = restored_tensors[idx++];
                    if (!restored_tensor_dict.ContainsKey(checkpoint_key))
                    {
                        restored_tensor_dict[checkpoint_key] = new Dictionary<string, Tensor>();
                    }
                    restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor;
                }
            }
            return restored_tensor_dict;
        }

        public IDictionary<string, IDictionary<string, Tensor>> restore(string file_prefix, CheckpointOptions? options = null) => restore(tf.constant(file_prefix), options);
    }
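    // A sketch of driving SingleDeviceSaver directly (illustrative only; the checkpoint key
    // and target path are hypothetical, and this assumes internal access):
    internal static class SingleDeviceSaverExample
    {
        internal static void Run()
        {
            var slices = new Dictionary<string, IDictionary<string, Tensor>>
            {
                // checkpoint_key -> (slice_spec -> tensor); an empty slice spec means the whole tensor.
                ["layer1/kernel/.ATTRIBUTES/VARIABLE_VALUE"] =
                    new Dictionary<string, Tensor> { [""] = tf.constant(new[] { 1f, 2f }) }
            };
            var saver = new SingleDeviceSaver(slices);
            var save_op = saver.save("/tmp/ckpt-example");     // writes via tf.io.save_v2
            var restored = saver.restore("/tmp/ckpt-example"); // checkpoint_key -> slice_spec -> Tensor
        }
    }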
| /// <summary> | |||||
| /// Saves checkpoints directly from multiple devices. | |||||
| /// Note that this is a low-level utility which stores Tensors in the keys | |||||
| /// specified by `SaveableObject`s.Higher-level utilities for object-based | |||||
| /// checkpointing are built on top of it. | |||||
| /// </summary> | |||||
| public class MultiDeviceSaver | |||||
| { | |||||
| private Dictionary<string, SingleDeviceSaver> _single_device_savers; | |||||
| private IDictionary<string, (RestoreFunc, RestoreFunc)> _registered_savers; | |||||
| private Dictionary<(string, string), RestoreFunc> _keys_to_restore_fn; | |||||
| private Dictionary<RestoreFunc, IList<(string, string)>> _restore_fn_to_keys; | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | |||||
| /// <param name="registered_savers"></param> | |||||
| /// <param name="call_with_mapped_capture"></param> | |||||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||||
| IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | |||||
| { | |||||
| _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | |||||
| _restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>(); | |||||
| Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | |||||
| foreach(var pair in serialized_tensors) | |||||
| { | |||||
| var obj = pair.Key; | |||||
| var tensor_dict = pair.Value; | |||||
| RestoreFunc restore_fn; | |||||
| if(obj == Trackable.None) | |||||
| { | |||||
| restore_fn = new RestoreFunc(x => null); | |||||
| } | |||||
| else | |||||
| { | |||||
| restore_fn = new RestoreFunc(x => | |||||
| { | |||||
| if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) | |||||
| { | |||||
| return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>); | |||||
| } | |||||
| throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}."); | |||||
| }); | |||||
| } | |||||
| foreach(var item in tensor_dict) | |||||
| { | |||||
| var checkpoint_key = item.Key; | |||||
| IDictionary<string, Tensor> spec_to_tensor; | |||||
| if(item.Value.TryGet<Tensor>(out var t)) | |||||
| { | |||||
| spec_to_tensor = new Dictionary<string, Tensor>(); | |||||
| spec_to_tensor[""] = t; | |||||
| } | |||||
| else | |||||
| { | |||||
| spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>(); | |||||
| } | |||||
| foreach(var spec in spec_to_tensor) | |||||
| { | |||||
| var slice_spec = spec.Key; | |||||
| var tensor = spec.Value; | |||||
| if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec))) | |||||
| { | |||||
| throw new ValueError("Recieved multiple tensors with the same checkpoint key and " + | |||||
| $"slice spec. This is invalid because one will overwrite the " + | |||||
| $"other in the checkpoint. This indicates a bug in the Checkpoint key-generation."); | |||||
| } | |||||
| _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; | |||||
| _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | |||||
| // Skip device-name processing here because the corresponding API is not available yet. | |||||
| var host_device = tensor.Device; | |||||
| var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>()); | |||||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||||
| { | |||||
| internal_dict[checkpoint_key] = new Dictionary<string, Tensor>(); | |||||
| } | |||||
| internal_dict[checkpoint_key][slice_spec] = tensor; | |||||
| } | |||||
| } | |||||
| } | |||||
| _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); | |||||
| _registered_savers = new Dictionary<string, (RestoreFunc, RestoreFunc)>(); | |||||
| if(registered_savers is not null && registered_savers.Count > 0) | |||||
| { | |||||
| // TODO: complete the implementation. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| public Operation save(Tensor file_prefix, CheckpointOptions? options = null) | |||||
| { | |||||
| if(options is null) | |||||
| { | |||||
| options = new CheckpointOptions(); | |||||
| } | |||||
| tf.device("CPU"); // may be risky. | |||||
| var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), | |||||
| constant_op.constant(".part"), constant_op.constant("_temp/part")); | |||||
| var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); | |||||
| IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); | |||||
| Operation save_fn() | |||||
| { | |||||
| List<Tensor> saved_prefixes = new(); | |||||
| foreach(var saver in _registered_savers) | |||||
| { | |||||
| // TODO: implement it later. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| int num_shards = _single_device_savers.Count; | |||||
| List<Operation> sharded_saves = new(); | |||||
| var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards"); | |||||
| string? last_device = null; | |||||
| int shard = 0; | |||||
| foreach(var pair in _single_device_savers.OrderBy(x => x.Key)) | |||||
| { | |||||
| var device = pair.Key; | |||||
| var saver = pair.Value; | |||||
| last_device = device; | |||||
| // Skip the extra device-name processing because the corresponding API is not available yet. | |||||
| tf.device(device); | |||||
| var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); | |||||
| saved_prefixes.Add(shard_prefix); | |||||
| sharded_saves.Add(saver.save(shard_prefix, options)); | |||||
| } | |||||
| using (var controller = ops.control_dependencies(sharded_saves.ToArray())) | |||||
| { | |||||
| string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; | |||||
| tf.device(merge_device); | |||||
| return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); | |||||
| } | |||||
| } | |||||
| if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1) | |||||
| { | |||||
| // TODO: implement it. Currently `autograph` does not support functions with no parameters. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| else | |||||
| { | |||||
| return save_fn(); | |||||
| } | |||||
| } | |||||
| public Operation save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix), options); | |||||
| public IDictionary<string, Operation> restore(Tensor file_prefix, CheckpointOptions? options = null) | |||||
| { | |||||
| if(options is null) | |||||
| { | |||||
| options = new CheckpointOptions(); | |||||
| } | |||||
| IDictionary<string, Operation> restore_func() | |||||
| { | |||||
| Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||||
| Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | |||||
| Dictionary<string, Operation> restore_ops = new(); | |||||
| foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) | |||||
| { | |||||
| var device = single_saver.Key; | |||||
| var saver = single_saver.Value; | |||||
| tf.device(device); | |||||
| var restored_tensor_dict = saver.restore(file_prefix, options); | |||||
| foreach(var pair in restored_tensor_dict) | |||||
| { | |||||
| var checkpoint_key = pair.Key; | |||||
| var slice_and_tensor = pair.Value; | |||||
| foreach(var item in slice_and_tensor) | |||||
| { | |||||
| var slice_spec = item.Key; | |||||
| var tensor = item.Value; | |||||
| var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||||
| if (!string.IsNullOrEmpty(slice_spec)) | |||||
| { | |||||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||||
| { | |||||
| Dictionary<string, Tensor> dict = new(); | |||||
| dict[slice_spec] = tensor; | |||||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||||
| } | |||||
| else | |||||
| { | |||||
| internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||||
| } | |||||
| restore_fn_input_count[restore_fn]--; | |||||
| if (restore_fn_input_count[restore_fn] == 0) | |||||
| { | |||||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||||
| foreach(var input in restore_fn_inputs[restore_fn]) | |||||
| { | |||||
| restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||||
| } | |||||
| var ret = restore_fn.DynamicInvoke(restored_tensors); | |||||
| if(ret is IDictionary<string, Operation>) | |||||
| { | |||||
| var dict = (IDictionary<string, Operation>)ret; | |||||
| restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| foreach(var item in _registered_savers) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| return restore_ops; | |||||
| } | |||||
| // TODO: complete the implementation. Currently skipped because the corresponding API is not available. | |||||
| bool has_custom_device_saver = false; | |||||
| if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver)) | |||||
| { | |||||
| // TODO: implement it. Currently `autograph` does not support functions with no parameters. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| else | |||||
| { | |||||
| return restore_func(); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Serializes to a SaverDef referencing the current graph. | |||||
| /// </summary> | |||||
| public SaverDef to_proto() | |||||
| { | |||||
| var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); | |||||
| var traced_save_func = tf.autograph.to_graph(_traced_save, TF_DataType.TF_STRING); | |||||
| var traced_restore_func = tf.autograph.to_graph(_traced_restore, TF_DataType.TF_STRING); | |||||
| var save_tensor = traced_save_func(filename_tensor); | |||||
| var restore_op = traced_restore_func(filename_tensor).op; | |||||
| return new SaverDef() | |||||
| { | |||||
| FilenameTensorName = filename_tensor.name, | |||||
| SaveTensorName = save_tensor.name, | |||||
| RestoreOpName = restore_op.name, | |||||
| Version = SaverDef.Types.CheckpointFormatVersion.V2 | |||||
| }; | |||||
| } | |||||
| private Tensor _traced_save(Tensor file_prefix) | |||||
| { | |||||
| var save_op = save(file_prefix); | |||||
| tf.device("cpu:0"); | |||||
| using (ops.control_dependencies(new object[]{ save_op })) | |||||
| { | |||||
| return array_ops.identity(file_prefix); | |||||
| } | |||||
| } | |||||
| private Tensor _traced_restore(Tensor file_prefix) | |||||
| { | |||||
| var restore_op = restore(file_prefix); | |||||
| tf.device("cpu:0"); | |||||
| using (ops.control_dependencies(restore_op.Values.ToArray())) | |||||
| { | |||||
| return array_ops.identity(file_prefix); | |||||
| } | |||||
| } | |||||
| public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | |||||
| { | |||||
| Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||||
| foreach (var saveable in saveables) | |||||
| { | |||||
| var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | |||||
| serialized_tensors[trackable] = trackable.serialize_to_tensors(); | |||||
| } | |||||
| return new MultiDeviceSaver(serialized_tensors, registered_savers, call_with_mapped_captures); | |||||
| } | |||||
| private static Tensor registered_saver_filename(Tensor filename_tensor, string saver_name) | |||||
| { | |||||
| return gen_ops.string_join(new Tensor[] { filename_tensor, constant_op.constant($"-{saver_name}") }); | |||||
| } | |||||
| private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) | |||||
| { | |||||
| return gen_ops.sharded_filename(filename_tensor, tf.constant(shard), num_shards); | |||||
| } | |||||
| } | |||||
| } | |||||
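A minimal usage sketch of the class above. It assumes a `Trackable` that implements `serialize_to_tensors()` yielding checkpoint-key-to-tensor entries; `myTrackable` is a hypothetical placeholder:

    var serialized = new Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>();
    serialized[myTrackable] = myTrackable.serialize_to_tensors();

    var saver = new MultiDeviceSaver(serialized);
    var save_op = saver.save("ckpt/model");                    // writes sharded files, then merges them
    var restore_ops = saver.restore(tf.constant("ckpt/model")); // checkpoint key -> restore Operation

In graph mode the returned `Operation`s still need to be run in a session; in eager mode they execute as part of the call.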
| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -90,4 +91,71 @@ namespace Tensorflow | |||||
| Dispose(false); | Dispose(false); | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| public abstract class DisposableTrackableObject: Trackable, IDisposable | |||||
| { | |||||
| protected IntPtr _handle; | |||||
| protected bool _disposed; | |||||
| protected DisposableTrackableObject() | |||||
| { } | |||||
| protected DisposableTrackableObject(IntPtr handle) | |||||
| => _handle = handle; | |||||
| private void Dispose(bool disposing) | |||||
| { | |||||
| if (_disposed) | |||||
| return; | |||||
| // First dispose managed state; managed objects might still use the unmanaged resources. | |||||
| if (disposing) | |||||
| { | |||||
| // dispose managed state (managed objects). | |||||
| DisposeManagedResources(); | |||||
| } | |||||
| // free unmanaged memory | |||||
| if (_handle != IntPtr.Zero) | |||||
| { | |||||
| // Call the appropriate methods to clean up | |||||
| // unmanaged resources here. | |||||
| // If disposing is false, | |||||
| // only the following code is executed. | |||||
| DisposeUnmanagedResources(_handle); | |||||
| _handle = IntPtr.Zero; | |||||
| } | |||||
| // Note disposing has been done. | |||||
| _disposed = true; | |||||
| } | |||||
| /// <summary> | |||||
| /// Dispose any managed resources. | |||||
| /// </summary> | |||||
| /// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks> | |||||
| protected virtual void DisposeManagedResources() | |||||
| { } | |||||
| /// <summary> | |||||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
| /// </summary> | |||||
| protected abstract void DisposeUnmanagedResources(IntPtr handle); | |||||
| public void Dispose() | |||||
| { | |||||
| Dispose(true); | |||||
| // This object will be cleaned up by the Dispose method. | |||||
| // Therefore, you should call GC.SuppressFinalize to | |||||
| // take this object off the finalization queue | |||||
| // and prevent finalization code for this object | |||||
| // from executing a second time. | |||||
| GC.SuppressFinalize(this); | |||||
| } | |||||
| ~DisposableTrackableObject() | |||||
| { | |||||
| Dispose(false); | |||||
| } | |||||
| } | |||||
| } | |||||
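A minimal sketch of a subclass plugging into this dispose pattern; the native release call is a hypothetical placeholder for whatever c_api binding actually owns the handle:

    public sealed class NativeTrackableResource : DisposableTrackableObject
    {
        public NativeTrackableResource(IntPtr handle) : base(handle) { }

        protected override void DisposeUnmanagedResources(IntPtr handle)
        {
            // Hypothetical: release the native object behind `handle` here,
            // e.g. the matching c_api delete function for the wrapped resource.
        }
    }

Because the base class implements both `Dispose()` and the finalizer, a subclass only decides what to free, never when.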
| @@ -0,0 +1,31 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Xml.Linq; | |||||
| using Tensorflow.Contexts; | |||||
| using static Tensorflow.ApiDef.Types; | |||||
| using static Tensorflow.CostGraphDef.Types; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| internal class execute | |||||
| { | |||||
| public static (DataType[], Tensor[]) convert_to_mixed_eager_tensors(Tensor[] values, Context ctx) | |||||
| { | |||||
| var v = values.Select(t => ops.convert_to_tensor(t, ctx: ctx)); | |||||
| var types = v.Select(t => t.dtype.as_datatype_enum()); | |||||
| return (types.ToArray(), v.ToArray()); | |||||
| } | |||||
| public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | |||||
| { | |||||
| string device_name = ctx.DeviceName; | |||||
| ctx.ensure_initialized(); | |||||
| var tensors = tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, num_outputs); | |||||
| return tensors; | |||||
| } | |||||
| } | |||||
| } | |||||
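A hedged sketch of calling `quick_execute`, assuming the alternating name/value attribute layout that `tf.Runner.TFE_Execute` consumes elsewhere in the codebase:

    // Runs a single Identity op eagerly; the attribute layout is an assumption.
    var results = execute.quick_execute(
        "Identity",
        num_outputs: 1,
        inputs: new Tensor[] { tf.constant(1) },
        attrs: new object[] { "T", tf.int32.as_datatype_enum() },
        ctx: tf.Context);
    var output = results[0];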
| @@ -0,0 +1,14 @@ | |||||
| namespace Tensorflow.Exceptions; | |||||
| public class AssertionError : TensorflowException | |||||
| { | |||||
| public AssertionError() : base() | |||||
| { | |||||
| } | |||||
| public AssertionError(string message) : base(message) | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -304,7 +304,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private static OpList stripped_op_list_for_graph(GraphDef graph_def) | |||||
| public static OpList stripped_op_list_for_graph(GraphDef graph_def) | |||||
| { | { | ||||
| var used_ops = ops_used_by_graph_def(graph_def); | var used_ops = ops_used_by_graph_def(graph_def); | ||||
| @@ -345,5 +345,89 @@ namespace Tensorflow | |||||
| return used_ops.ToArray(); | return used_ops.ToArray(); | ||||
| } | } | ||||
| private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value) | |||||
| { | |||||
| foreach (var attr_def in op_def.Attr) | |||||
| { | |||||
| if (attr_def.Name == attr_name) | |||||
| { | |||||
| if (attr_def.DefaultValue is null) return false; | |||||
| // TODO: add new c_api `EqualAttrValueWrapper` and complete the check. | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def) | |||||
| { | |||||
| Dictionary<string, FunctionDef> op_name_to_function = new(); | |||||
| foreach (var function_def in meta_graph_def.GraphDef.Library.Function) | |||||
| { | |||||
| op_name_to_function[function_def.Signature.Name] = function_def; | |||||
| } | |||||
| Action<NodeDef> _strip_node_default_valued_attrs = (node_def) => | |||||
| { | |||||
| if (op_name_to_function.ContainsKey(node_def.Op)) return; | |||||
| var op_def = op_def_registry.GetOpDef(node_def.Op); | |||||
| if(op_def is null) return; | |||||
| HashSet<string> attrs_to_strip = new(); | |||||
| foreach (var attr in node_def.Attr) | |||||
| { | |||||
| if (is_default_attr_value(op_def, attr.Key, attr.Value)) | |||||
| { | |||||
| attrs_to_strip.Add(attr.Key); | |||||
| } | |||||
| } | |||||
| foreach (var attr in attrs_to_strip) | |||||
| { | |||||
| node_def.Attr.Remove(attr); | |||||
| } | |||||
| }; | |||||
| foreach (var node_def in meta_graph_def.GraphDef.Node) | |||||
| { | |||||
| _strip_node_default_valued_attrs(node_def); | |||||
| } | |||||
| foreach (var function_def in meta_graph_def.GraphDef.Library.Function) | |||||
| { | |||||
| foreach (var function_node_def in function_def.NodeDef) | |||||
| { | |||||
| _strip_node_default_valued_attrs(function_node_def); | |||||
| } | |||||
| } | |||||
| meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; | |||||
| } | |||||
| /// <summary> | |||||
| /// Extract the Op name from a Tensor name. | |||||
| /// </summary> | |||||
| /// <param name="tensor_name"></param> | |||||
| /// <returns></returns> | |||||
| public static string op_name(string tensor_name) | |||||
| { | |||||
| if (string.IsNullOrEmpty(tensor_name)) | |||||
| { | |||||
| throw new ValueError($"Tensor name cannot be empty or None. Received: {tensor_name}."); | |||||
| } | |||||
| if (tensor_name.StartsWith("^")) | |||||
| { | |||||
| tensor_name = tensor_name.Substring(1); | |||||
| } | |||||
| if (tensor_name.Contains(":")) | |||||
| { | |||||
| return tensor_name.Split(':')[0]; | |||||
| } | |||||
| return tensor_name; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
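For reference, `op_name` strips a leading control-dependency marker (`^`) and the output index after `:`. Assuming these statics live on the surrounding class (referred to here as `meta_graph`):

    meta_graph.op_name("dense/BiasAdd:0"); // "dense/BiasAdd"
    meta_graph.op_name("^init_op");        // "init_op"
    meta_graph.op_name("plain_name");      // "plain_name"

`strip_graph_default_valued_attrs` complements this by removing node attributes that match their `OpDef` defaults, shrinking the serialized `MetaGraphDef`; the actual value comparison is still pending the `EqualAttrValueWrapper` c_api binding noted above.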
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Functions | namespace Tensorflow.Functions | ||||
| @@ -10,7 +11,7 @@ namespace Tensorflow.Functions | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| public class ConcreteFunction | |||||
| public class ConcreteFunction: Trackable | |||||
| { | { | ||||
| FuncGraph func_graph; | FuncGraph func_graph; | ||||
| ForwardBackwardCall forward_backward; | ForwardBackwardCall forward_backward; | ||||
| @@ -1,16 +1,23 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class Function | |||||
| public class Function: Trackable | |||||
| { | { | ||||
| #pragma warning disable CS0169 // The field 'Function._handle' is never used | #pragma warning disable CS0169 // The field 'Function._handle' is never used | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| #pragma warning restore CS0169 // The field 'Function._handle' is never used | #pragma warning restore CS0169 // The field 'Function._handle' is never used | ||||
| public string Name { get; set; } | |||||
| public Function() | public Function() | ||||
| { | { | ||||
| } | } | ||||
| public Function(string name) | |||||
| { | |||||
| Name = name; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -6,14 +7,14 @@ namespace Tensorflow.Graphs | |||||
| { | { | ||||
| public class AutoGraph | public class AutoGraph | ||||
| { | { | ||||
| public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func) | |||||
| public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func, TF_DataType dtype = TF_DataType.TF_INT32) | |||||
| { | { | ||||
| string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | ||||
| var graph = new FuncGraph(func_name); | var graph = new FuncGraph(func_name); | ||||
| graph.as_default(); | graph.as_default(); | ||||
| var input = tf.placeholder(tf.int32); | |||||
| var input = tf.placeholder(dtype); | |||||
| var output = func(input); | var output = func(input); | ||||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
| @@ -26,25 +27,33 @@ namespace Tensorflow.Graphs | |||||
| return (Tensor input) => | return (Tensor input) => | ||||
| { | { | ||||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||||
| tf.Context.DeviceName, | |||||
| func_name, | |||||
| new[] { input }, | |||||
| null, | |||||
| 1); | |||||
| return result[0]; | |||||
| if (tf.executing_eagerly()) | |||||
| { | |||||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||||
| tf.Context.DeviceName, | |||||
| func_name, | |||||
| new[] { input }, | |||||
| null, | |||||
| 1); | |||||
| return result[0]; | |||||
| } | |||||
| using (var s = tf.Session(input.graph)) | |||||
| { | |||||
| var output = func(input); | |||||
| return output; | |||||
| } | |||||
| }; | }; | ||||
| } | } | ||||
| public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func) | |||||
| public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func, params TF_DataType[] dtypes) | |||||
| { | { | ||||
| string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | ||||
| var graph = new FuncGraph(func_name); | var graph = new FuncGraph(func_name); | ||||
| graph.as_default(); | graph.as_default(); | ||||
| var input1 = tf.placeholder(tf.int32); | |||||
| var input2 = tf.placeholder(tf.int32); | |||||
| var input1 = tf.placeholder(dtypes.Length >= 1 ? dtypes[0] : tf.int32); | |||||
| var input2 = tf.placeholder(dtypes.Length >= 2 ? dtypes[1] : tf.int32); | |||||
| var output = func(input1, input2); | var output = func(input1, input2); | ||||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
| @@ -56,13 +65,22 @@ namespace Tensorflow.Graphs | |||||
| return (Tensor a, Tensor b) => | return (Tensor a, Tensor b) => | ||||
| { | { | ||||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||||
| if (tf.executing_eagerly()) | |||||
| { | |||||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||||
| tf.Context.DeviceName, | tf.Context.DeviceName, | ||||
| func_name, | func_name, | ||||
| new[] { a, b }, | new[] { a, b }, | ||||
| null, | null, | ||||
| 1); | 1); | ||||
| return result[0]; | |||||
| return result[0]; | |||||
| } | |||||
| using (var s = tf.Session(a.graph)) | |||||
| { | |||||
| Debug.Assert(a.graph == b.graph); | |||||
| var output = func(a, b); | |||||
| return output; | |||||
| } | |||||
| }; | }; | ||||
| } | } | ||||
| } | } | ||||
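A sketch of the new dtype-aware overloads in eager mode; the lambdas here are illustrative:

    // Trace a single-input function with a float placeholder instead of the old int32 default.
    var square = tf.autograph.to_graph(x => x * x, TF_DataType.TF_FLOAT);
    var y = square(tf.constant(2.0f)); // 4.0f

    // Two-input variant: dtypes are supplied positionally and fall back to int32.
    var add = tf.autograph.to_graph((Tensor a, Tensor b) => a + b, TF_DataType.TF_FLOAT, TF_DataType.TF_FLOAT);
    var z = add(tf.constant(1.0f), tf.constant(2.0f)); // 3.0f

Under graph mode the returned wrapper now re-invokes the original function inside a session on the input's graph instead of calling `TFE_Execute`.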
| @@ -1,9 +1,12 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.ArgsDefinition { | namespace Tensorflow.Keras.ArgsDefinition { | ||||
| public class ELUArgs : LayerArgs { | |||||
| public float Alpha { get; set; } = 0.1f; | |||||
| } | |||||
| public class ELUArgs : AutoSerializeLayerArgs | |||||
| { | |||||
| [JsonProperty("alpha")] | |||||
| public float Alpha { get; set; } = 0.1f; | |||||
| } | |||||
| } | } | ||||
| @@ -1,14 +1,16 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class LeakyReLuArgs : LayerArgs | |||||
| public class LeakyReLuArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Negative slope coefficient. | /// Negative slope coefficient. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("alpha")] | |||||
| public float Alpha { get; set; } = 0.3f; | public float Alpha { get; set; } = 0.3f; | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,9 +1,12 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.ArgsDefinition { | namespace Tensorflow.Keras.ArgsDefinition { | ||||
| public class SoftmaxArgs : LayerArgs { | |||||
| public Axis axis { get; set; } = -1; | |||||
| } | |||||
| public class SoftmaxArgs : AutoSerializeLayerArgs | |||||
| { | |||||
| [JsonProperty("axis")] | |||||
| public Axis axis { get; set; } = -1; | |||||
| } | |||||
| } | } | ||||
| @@ -1,3 +1,5 @@ | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class AttentionArgs : BaseDenseAttentionArgs | public class AttentionArgs : BaseDenseAttentionArgs | ||||
| @@ -6,6 +8,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| /// <summary> | /// <summary> | ||||
| /// If `true`, will create a scalar variable to scale the attention scores. | /// If `true`, will create a scalar variable to scale the attention scores. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("use_scale")] | |||||
| public bool use_scale { get; set; } = false; | public bool use_scale { get; set; } = false; | ||||
| /// <summary> | /// <summary> | ||||
| @@ -14,6 +17,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| /// and key vectors. `"concat"` refers to the hyperbolic tangent of the | /// and key vectors. `"concat"` refers to the hyperbolic tangent of the | ||||
| /// concatenation of the query and key vectors. | /// concatenation of the query and key vectors. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("score_mode")] | |||||
| public string score_mode { get; set; } = "dot"; | public string score_mode { get; set; } = "dot"; | ||||
| } | } | ||||
| @@ -1,6 +1,8 @@ | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class BaseDenseAttentionArgs : LayerArgs | |||||
| public class BaseDenseAttentionArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| @@ -14,6 +16,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| /// Float between 0 and 1. Fraction of the units to drop for the | /// Float between 0 and 1. Fraction of the units to drop for the | ||||
| /// attention scores. | /// attention scores. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("dropout")] | |||||
| public float dropout { get; set; } = 0f; | public float dropout { get; set; } = 0f; | ||||
| } | } | ||||
| @@ -1,22 +1,40 @@ | |||||
| using Newtonsoft.Json; | |||||
| using System; | using System; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class MultiHeadAttentionArgs : LayerArgs | |||||
| public class MultiHeadAttentionArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| [JsonProperty("num_heads")] | |||||
| public int NumHeads { get; set; } | public int NumHeads { get; set; } | ||||
| [JsonProperty("key_dim")] | |||||
| public int KeyDim { get; set; } | public int KeyDim { get; set; } | ||||
| [JsonProperty("value_dim")] | |||||
| public int? ValueDim { get; set; } = null; | public int? ValueDim { get; set; } = null; | ||||
| [JsonProperty("dropout")] | |||||
| public float Dropout { get; set; } = 0f; | public float Dropout { get; set; } = 0f; | ||||
| [JsonProperty("use_bias")] | |||||
| public bool UseBias { get; set; } = true; | public bool UseBias { get; set; } = true; | ||||
| [JsonProperty("output_shape")] | |||||
| public Shape OutputShape { get; set; } = null; | public Shape OutputShape { get; set; } = null; | ||||
| [JsonProperty("attention_axes")] | |||||
| public Shape AttentionAxis { get; set; } = null; | public Shape AttentionAxis { get; set; } = null; | ||||
| [JsonProperty("kernel_initializer")] | |||||
| public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | ||||
| [JsonProperty("bias_initializer")] | |||||
| public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | ||||
| [JsonProperty("kernel_regularizer")] | |||||
| public IRegularizer KernelRegularizer { get; set; } = null; | public IRegularizer KernelRegularizer { get; set; } = null; | ||||
| [JsonProperty("bias_regularizer")] | |||||
| public IRegularizer BiasRegularizer { get; set; } = null; | public IRegularizer BiasRegularizer { get; set; } = null; | ||||
| [JsonProperty("kernel_constraint")] | |||||
| public Action KernelConstraint { get; set; } = null; | public Action KernelConstraint { get; set; } = null; | ||||
| [JsonProperty("bias_constraint")] | |||||
| public Action BiasConstraint { get; set; } = null; | public Action BiasConstraint { get; set; } = null; | ||||
| [JsonProperty("activity_regularizer")] | |||||
| public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } | |||||
| // TODO: Add `key_shape`, `value_shape`, `query_shape`. | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,25 @@ | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| /// <summary> | |||||
| /// This class adds nothing beyond `LayerArgs` except serialization of the common attributes. | |||||
| /// It's used when serializing the model to the `tf` format. | |||||
| /// If the `get_config` of a `Layer` in the Python code of TensorFlow calls `super().get_config`, | |||||
| /// then the args definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`. | |||||
| /// </summary> | |||||
| public class AutoSerializeLayerArgs: LayerArgs | |||||
| { | |||||
| [JsonProperty("name")] | |||||
| public override string Name { get => base.Name; set => base.Name = value; } | |||||
| [JsonProperty("dtype")] | |||||
| public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||||
| [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | |||||
| public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||||
| [JsonProperty("trainable")] | |||||
| public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } | |||||
| } | |||||
| } | |||||
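A minimal sketch of how a new args class opts in, mirroring the pattern used by the layer args below; the class and its property are hypothetical:

    public class MyActivationArgs : AutoSerializeLayerArgs
    {
        [JsonProperty("alpha")]
        public float Alpha { get; set; } = 1.0f;
        // name/dtype/batch_input_shape/trainable are serialized by the base-class overrides.
    }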
| @@ -1,31 +1,65 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class ConvolutionalArgs : LayerArgs | |||||
| public class ConvolutionalArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| public int Rank { get; set; } = 2; | public int Rank { get; set; } = 2; | ||||
| [JsonProperty("filters")] | |||||
| public int Filters { get; set; } | public int Filters { get; set; } | ||||
| public int NumSpatialDims { get; set; } = Unknown; | public int NumSpatialDims { get; set; } = Unknown; | ||||
| [JsonProperty("kernel_size")] | |||||
| public Shape KernelSize { get; set; } = 5; | public Shape KernelSize { get; set; } = 5; | ||||
| /// <summary> | /// <summary> | ||||
| /// specifying the stride length of the convolution. | /// specifying the stride length of the convolution. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("strides")] | |||||
| public Shape Strides { get; set; } = (1, 1); | public Shape Strides { get; set; } = (1, 1); | ||||
| [JsonProperty("padding")] | |||||
| public string Padding { get; set; } = "valid"; | public string Padding { get; set; } = "valid"; | ||||
| [JsonProperty("data_format")] | |||||
| public string DataFormat { get; set; } | public string DataFormat { get; set; } | ||||
| [JsonProperty("dilation_rate")] | |||||
| public Shape DilationRate { get; set; } = (1, 1); | public Shape DilationRate { get; set; } = (1, 1); | ||||
| [JsonProperty("groups")] | |||||
| public int Groups { get; set; } = 1; | public int Groups { get; set; } = 1; | ||||
| public Activation Activation { get; set; } | public Activation Activation { get; set; } | ||||
| private string _activationName; | |||||
| [JsonProperty("activation")] | |||||
| public string ActivationName | |||||
| { | |||||
| get | |||||
| { | |||||
| if (string.IsNullOrEmpty(_activationName)) | |||||
| { | |||||
| return Activation.Method.Name; | |||||
| } | |||||
| else | |||||
| { | |||||
| return _activationName; | |||||
| } | |||||
| } | |||||
| set | |||||
| { | |||||
| _activationName = value; | |||||
| } | |||||
| } | |||||
| [JsonProperty("use_bias")] | |||||
| public bool UseBias { get; set; } | public bool UseBias { get; set; } | ||||
| [JsonProperty("kernel_initializer")] | |||||
| public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | ||||
| [JsonProperty("bias_initializer")] | |||||
| public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | ||||
| [JsonProperty("kernel_regularizer")] | |||||
| public IRegularizer KernelRegularizer { get; set; } | public IRegularizer KernelRegularizer { get; set; } | ||||
| [JsonProperty("bias_regularizer")] | |||||
| public IRegularizer BiasRegularizer { get; set; } | public IRegularizer BiasRegularizer { get; set; } | ||||
| [JsonProperty("kernel_constraint")] | |||||
| public Action KernelConstraint { get; set; } | public Action KernelConstraint { get; set; } | ||||
| [JsonProperty("bias_constraint")] | |||||
| public Action BiasConstraint { get; set; } | public Action BiasConstraint { get; set; } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,13 +1,18 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Xml.Linq; | |||||
| using Tensorflow.Operations.Initializers; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| // TODO: `activity_regularizer` | |||||
| public class DenseArgs : LayerArgs | public class DenseArgs : LayerArgs | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Positive integer, dimensionality of the output space. | /// Positive integer, dimensionality of the output space. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("units")] | |||||
| public int Units { get; set; } | public int Units { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -15,39 +20,74 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| /// </summary> | /// </summary> | ||||
| public Activation Activation { get; set; } | public Activation Activation { get; set; } | ||||
| private string _activationName; | |||||
| [JsonProperty("activation")] | |||||
| public string ActivationName | |||||
| { | |||||
| get | |||||
| { | |||||
| if (string.IsNullOrEmpty(_activationName)) | |||||
| { | |||||
| return Activation.Method.Name; | |||||
| } | |||||
| else | |||||
| { | |||||
| return _activationName; | |||||
| } | |||||
| } | |||||
| set | |||||
| { | |||||
| _activationName = value; | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Whether the layer uses a bias vector. | /// Whether the layer uses a bias vector. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("use_bias")] | |||||
| public bool UseBias { get; set; } = true; | public bool UseBias { get; set; } = true; | ||||
| /// <summary> | /// <summary> | ||||
| /// Initializer for the `kernel` weights matrix. | /// Initializer for the `kernel` weights matrix. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("kernel_initializer")] | |||||
| public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | ||||
| /// <summary> | /// <summary> | ||||
| /// Initializer for the bias vector. | /// Initializer for the bias vector. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("bias_initializer")] | |||||
| public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | ||||
| /// <summary> | /// <summary> | ||||
| /// Regularizer function applied to the `kernel` weights matrix. | /// Regularizer function applied to the `kernel` weights matrix. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("kernel_regularizer")] | |||||
| public IRegularizer KernelRegularizer { get; set; } | public IRegularizer KernelRegularizer { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// Regularizer function applied to the bias vector. | /// Regularizer function applied to the bias vector. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("bias_regularizer")] | |||||
| public IRegularizer BiasRegularizer { get; set; } | public IRegularizer BiasRegularizer { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// Constraint function applied to the `kernel` weights matrix. | /// Constraint function applied to the `kernel` weights matrix. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("kernel_constraint")] | |||||
| public Action KernelConstraint { get; set; } | public Action KernelConstraint { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// Constraint function applied to the bias vector. | /// Constraint function applied to the bias vector. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("bias_constraint")] | |||||
| public Action BiasConstraint { get; set; } | public Action BiasConstraint { get; set; } | ||||
| [JsonProperty("name")] | |||||
| public override string Name { get => base.Name; set => base.Name = value; } | |||||
| [JsonProperty("dtype")] | |||||
| public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||||
| [JsonProperty("trainable")] | |||||
| public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } | |||||
| } | } | ||||
| } | } | ||||
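`ActivationName` serializes the activation by name: on save, the getter falls back to the delegate's method name when no explicit name was configured; on load, the deserializer assigns the stored string through the setter, which then takes precedence. A sketch of the round trip, where `someRelu` is a hypothetical `Activation` delegate:

    var args = new DenseArgs { Units = 4, Activation = someRelu };
    Console.WriteLine(args.ActivationName); // falls back to someRelu.Method.Name

    args.ActivationName = "relu";           // e.g. assigned by the JSON deserializer
    Console.WriteLine(args.ActivationName); // "relu"; the configured name wins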
| @@ -1,9 +1,10 @@ | |||||
| using Newtonsoft.Json; | |||||
| using System; | using System; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Core | |||||
| { | { | ||||
| public class EinsumDenseArgs : LayerArgs | |||||
| public class EinsumDenseArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// An equation describing the einsum to perform. This equation must | /// An equation describing the einsum to perform. This equation must | ||||
| @@ -11,6 +12,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| /// `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis | /// `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis | ||||
| /// expression sequence. | /// expression sequence. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("equation")] | |||||
| public string Equation { get; set; } | public string Equation { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -19,6 +21,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| /// None for any dimension that is unknown or can be inferred from the input | /// None for any dimension that is unknown or can be inferred from the input | ||||
| /// shape. | /// shape. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("output_shape")] | |||||
| public Shape OutputShape { get; set; } | public Shape OutputShape { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -26,41 +29,70 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| /// Each character in the `bias_axes` string should correspond to a character | /// Each character in the `bias_axes` string should correspond to a character | ||||
| /// in the output portion of the `equation` string. | /// in the output portion of the `equation` string. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("bias_axes")] | |||||
| public string BiasAxes { get; set; } = null; | public string BiasAxes { get; set; } = null; | ||||
| /// <summary> | /// <summary> | ||||
| /// Activation function to use. | /// Activation function to use. | ||||
| /// </summary> | /// </summary> | ||||
| public Activation Activation { get; set; } | public Activation Activation { get; set; } | ||||
| private string _activationName; | |||||
| [JsonProperty("activation")] | |||||
| public string ActivationName | |||||
| { | |||||
| get | |||||
| { | |||||
| if (string.IsNullOrEmpty(_activationName)) | |||||
| { | |||||
| return Activation.Method.Name; | |||||
| } | |||||
| else | |||||
| { | |||||
| return _activationName; | |||||
| } | |||||
| } | |||||
| set | |||||
| { | |||||
| _activationName = value; | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Initializer for the `kernel` weights matrix. | /// Initializer for the `kernel` weights matrix. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("kernel_initializer")] | |||||
| public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | ||||
| /// <summary> | /// <summary> | ||||
| /// Initializer for the bias vector. | /// Initializer for the bias vector. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("bias_initializer")] | |||||
| public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | ||||
| /// <summary> | /// <summary> | ||||
| /// Regularizer function applied to the `kernel` weights matrix. | /// Regularizer function applied to the `kernel` weights matrix. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("kernel_regularizer")] | |||||
| public IRegularizer KernelRegularizer { get; set; } | public IRegularizer KernelRegularizer { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// Regularizer function applied to the bias vector. | /// Regularizer function applied to the bias vector. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("bias_regularizer")] | |||||
| public IRegularizer BiasRegularizer { get; set; } | public IRegularizer BiasRegularizer { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// Constraint function applied to the `kernel` weights matrix. | /// Constraint function applied to the `kernel` weights matrix. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("kernel_constraint")] | |||||
| public Action KernelConstraint { get; set; } | public Action KernelConstraint { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// Constraint function applied to the bias vector. | /// Constraint function applied to the bias vector. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("bias_constraint")] | |||||
| public Action BiasConstraint { get; set; } | public Action BiasConstraint { get; set; } | ||||
| [JsonProperty("activity_regularizer")] | |||||
| public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,11 +1,22 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class EmbeddingArgs : LayerArgs | |||||
| public class EmbeddingArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| [JsonProperty("input_dim")] | |||||
| public int InputDim { get; set; } | public int InputDim { get; set; } | ||||
| [JsonProperty("output_dim")] | |||||
| public int OutputDim { get; set; } | public int OutputDim { get; set; } | ||||
| [JsonProperty("mask_zero")] | |||||
| public bool MaskZero { get; set; } | public bool MaskZero { get; set; } | ||||
| [JsonProperty("input_length")] | |||||
| public int InputLength { get; set; } = -1; | public int InputLength { get; set; } = -1; | ||||
| [JsonProperty("embeddings_initializer")] | |||||
| public IInitializer EmbeddingsInitializer { get; set; } | public IInitializer EmbeddingsInitializer { get; set; } | ||||
| [JsonProperty("activity_regularizer")] | |||||
| public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } | |||||
| // TODO: `embeddings_regularizer`, `embeddings_constraint`. | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,9 +1,22 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Serialization; | |||||
| using Tensorflow.Keras.Common; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class InputLayerArgs : LayerArgs | public class InputLayerArgs : LayerArgs | ||||
| { | { | ||||
| [JsonIgnore] | |||||
| public Tensor InputTensor { get; set; } | public Tensor InputTensor { get; set; } | ||||
| public bool Sparse { get; set; } | |||||
| [JsonProperty("sparse")] | |||||
| public virtual bool Sparse { get; set; } | |||||
| [JsonProperty("ragged")] | |||||
| public bool Ragged { get; set; } | public bool Ragged { get; set; } | ||||
| [JsonProperty("name")] | |||||
| public override string Name { get => base.Name; set => base.Name = value; } | |||||
| [JsonProperty("dtype")] | |||||
| public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||||
| [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | |||||
| public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,16 +0,0 @@ | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Keras.ArgsDefinition { | |||||
| public class Cropping2DArgs : LayerArgs { | |||||
| /// <summary> | |||||
| /// channel last: (b, h, w, c) | |||||
| /// channels_first: (b, c, h, w) | |||||
| /// </summary> | |||||
| public enum DataFormat { channels_first = 0, channels_last = 1 } | |||||
| /// <summary> | |||||
| /// Accept: int[1][2], int[1][1], int[2][2] | |||||
| /// </summary> | |||||
| public NDArray cropping { get; set; } | |||||
| public DataFormat data_format { get; set; } = DataFormat.channels_last; | |||||
| } | |||||
| } | |||||
| @@ -1,16 +0,0 @@ | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Keras.ArgsDefinition { | |||||
| public class Cropping3DArgs : LayerArgs { | |||||
| /// <summary> | |||||
| /// channel last: (b, h, w, c) | |||||
| /// channels_first: (b, c, h, w) | |||||
| /// </summary> | |||||
| public enum DataFormat { channels_first = 0, channels_last = 1 } | |||||
| /// <summary> | |||||
| /// Accept: int[1][3], int[1][1], int[3][2] | |||||
| /// </summary> | |||||
| public NDArray cropping { get; set; } | |||||
| public DataFormat data_format { get; set; } = DataFormat.channels_last; | |||||
| } | |||||
| } | |||||
| @@ -1,10 +0,0 @@ | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Keras.ArgsDefinition { | |||||
| public class CroppingArgs : LayerArgs { | |||||
| /// <summary> | |||||
| /// Accept length 1 or 2 | |||||
| /// </summary> | |||||
| public NDArray cropping { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -1,8 +1,9 @@ | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class DataAdapterArgs | |||||
| public class DataAdapterArgs: IKerasConfig | |||||
| { | { | ||||
| public Tensor X { get; set; } | public Tensor X { get; set; } | ||||
| public Tensor Y { get; set; } | public Tensor Y { get; set; } | ||||
| @@ -1,8 +1,9 @@ | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class DataHandlerArgs | |||||
| public class DataHandlerArgs: IKerasConfig | |||||
| { | { | ||||
| public Tensor X { get; set; } | public Tensor X { get; set; } | ||||
| public Tensor Y { get; set; } | public Tensor Y { get; set; } | ||||
| @@ -1,51 +1,54 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class LayerArgs | |||||
| [JsonObject(MemberSerialization.OptIn)] | |||||
| public class LayerArgs: IKerasConfig | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Indicates whether the layer's weights are updated during training | /// Indicates whether the layer's weights are updated during training | ||||
| /// and whether the layer's updates are run during training. | /// and whether the layer's updates are run during training. | ||||
| /// </summary> | /// </summary> | ||||
| public bool Trainable { get; set; } = true; | |||||
| public string Name { get; set; } | |||||
| public virtual bool Trainable { get; set; } = true; | |||||
| public virtual string Name { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Only applicable to input layers. | /// Only applicable to input layers. | ||||
| /// </summary> | /// </summary> | ||||
| public TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; | |||||
| public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; | |||||
| /// <summary> | /// <summary> | ||||
| /// Whether the `call` method can be used to build a TF graph without issues. | /// Whether the `call` method can be used to build a TF graph without issues. | ||||
| /// This attribute has no effect if the model is created using the Functional | /// This attribute has no effect if the model is created using the Functional | ||||
| /// API. Instead, `model.dynamic` is determined based on the internal layers. | /// API. Instead, `model.dynamic` is determined based on the internal layers. | ||||
| /// </summary> | /// </summary> | ||||
| public bool Dynamic { get; set; } = false; | |||||
| public virtual bool Dynamic { get; set; } = false; | |||||
| /// <summary> | /// <summary> | ||||
| /// Only applicable to input layers. | /// Only applicable to input layers. | ||||
| /// </summary> | /// </summary> | ||||
| public Shape InputShape { get; set; } | |||||
| public virtual Shape InputShape { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Only applicable to input layers. | /// Only applicable to input layers. | ||||
| /// </summary> | /// </summary> | ||||
| public Shape BatchInputShape { get; set; } | |||||
| public virtual Shape BatchInputShape { get; set; } | |||||
| public int BatchSize { get; set; } = -1; | |||||
| public virtual int BatchSize { get; set; } = -1; | |||||
| /// <summary> | /// <summary> | ||||
| /// Initial weight values. | /// Initial weight values. | ||||
| /// </summary> | /// </summary> | ||||
| public float[] Weights { get; set; } | |||||
| public virtual float[] Weights { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Regularizer function applied to the output of the layer(its "activation"). | /// Regularizer function applied to the output of the layer(its "activation"). | ||||
| /// </summary> | /// </summary> | ||||
| public IRegularizer ActivityRegularizer { get; set; } | |||||
| public virtual IRegularizer ActivityRegularizer { get; set; } | |||||
| public bool Autocast { get; set; } | |||||
| public virtual bool Autocast { get; set; } | |||||
| public bool IsFromConfig { get; set; } | |||||
| public virtual bool IsFromConfig { get; set; } | |||||
| } | } | ||||
| } | } | ||||
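With `MemberSerialization.OptIn`, Json.NET writes only the properties carrying `[JsonProperty]`; everything else on the args object stays out of the config. A sketch of the effect, assuming Json.NET defaults:

    var args = new DenseArgs { Units = 8, ActivationName = "relu", Name = "dense_1" };
    var json = JsonConvert.SerializeObject(args);
    // json contains "units", "activation", "use_bias", "name", "dtype", "trainable", ...
    // but not InputShape, BatchSize, or other members without [JsonProperty].

Making the base properties `virtual` lets subclasses such as `DenseArgs` re-declare them with `[JsonProperty]` so the opt-in filter picks them up.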
| @@ -1,6 +0,0 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Lstm | |||||
| { | |||||
| public class LSTMCellArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -4,6 +4,7 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| // TODO: complete the implementation | |||||
| public class MergeArgs : LayerArgs | public class MergeArgs : LayerArgs | ||||
| { | { | ||||
| public Tensors Inputs { get; set; } | public Tensors Inputs { get; set; } | ||||
| @@ -1,6 +1,8 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class NodeArgs | |||||
| public class NodeArgs: IKerasConfig | |||||
| { | { | ||||
| public ILayer[] InboundLayers { get; set; } | public ILayer[] InboundLayers { get; set; } | ||||
| public int[] NodeIndices { get; set; } | public int[] NodeIndices { get; set; } | ||||
| @@ -1,21 +1,37 @@ | |||||
| using static Tensorflow.Binding; | |||||
| using Newtonsoft.Json; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class BatchNormalizationArgs : LayerArgs | |||||
| public class BatchNormalizationArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| [JsonProperty("axis")] | |||||
| public Shape Axis { get; set; } = -1; | public Shape Axis { get; set; } = -1; | ||||
| [JsonProperty("momentum")] | |||||
| public float Momentum { get; set; } = 0.99f; | public float Momentum { get; set; } = 0.99f; | ||||
| [JsonProperty("epsilon")] | |||||
| public float Epsilon { get; set; } = 1e-3f; | public float Epsilon { get; set; } = 1e-3f; | ||||
| [JsonProperty("center")] | |||||
| public bool Center { get; set; } = true; | public bool Center { get; set; } = true; | ||||
| [JsonProperty("scale")] | |||||
| public bool Scale { get; set; } = true; | public bool Scale { get; set; } = true; | ||||
| [JsonProperty("beta_initializer")] | |||||
| public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; | public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; | ||||
| [JsonProperty("gamma_initializer")] | |||||
| public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; | public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; | ||||
| [JsonProperty("moving_mean_initializer")] | |||||
| public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer; | public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer; | ||||
| [JsonProperty("moving_variance_initializer")] | |||||
| public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer; | public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer; | ||||
| [JsonProperty("beta_regularizer")] | |||||
| public IRegularizer BetaRegularizer { get; set; } | public IRegularizer BetaRegularizer { get; set; } | ||||
| [JsonProperty("gamma_regularizer")] | |||||
| public IRegularizer GammaRegularizer { get; set; } | public IRegularizer GammaRegularizer { get; set; } | ||||
| // TODO: `beta_constraint` and `gamma_constraint`. | |||||
| [JsonProperty("renorm")] | |||||
| public bool Renorm { get; set; } | public bool Renorm { get; set; } | ||||
| // TODO: `renorm_clipping` and `virtual_batch_size`. | |||||
| [JsonProperty("renorm_momentum")] | |||||
| public float RenormMomentum { get; set; } = 0.99f; | public float RenormMomentum { get; set; } = 0.99f; | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,16 +1,27 @@ | |||||
| using static Tensorflow.Binding; | |||||
| using Newtonsoft.Json; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class LayerNormalizationArgs : LayerArgs | |||||
| public class LayerNormalizationArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| [JsonProperty("axis")] | |||||
| public Axis Axis { get; set; } = -1; | public Axis Axis { get; set; } = -1; | ||||
| [JsonProperty("epsilon")] | |||||
| public float Epsilon { get; set; } = 1e-3f; | public float Epsilon { get; set; } = 1e-3f; | ||||
| [JsonProperty("center")] | |||||
| public bool Center { get; set; } = true; | public bool Center { get; set; } = true; | ||||
| [JsonProperty("scale")] | |||||
| public bool Scale { get; set; } = true; | public bool Scale { get; set; } = true; | ||||
| [JsonProperty("beta_initializer")] | |||||
| public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; | public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; | ||||
| [JsonProperty("gamma_initializer")] | |||||
| public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; | public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; | ||||
| [JsonProperty("beta_regularizer")] | |||||
| public IRegularizer BetaRegularizer { get; set; } | public IRegularizer BetaRegularizer { get; set; } | ||||
| [JsonProperty("gamma_regularizer")] | |||||
| public IRegularizer GammaRegularizer { get; set; } | public IRegularizer GammaRegularizer { get; set; } | ||||
| // TODO: `beta_constraint` and `gamma_constraint`. | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,8 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class OptimizerV2Args | |||||
| public class OptimizerV2Args: IKerasConfig | |||||
| { | { | ||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| public float LearningRate { get; set; } = 0.001f; | public float LearningRate { get; set; } = 0.001f; | ||||
| @@ -1,6 +1,8 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class Pooling1DArgs : LayerArgs | |||||
| public class Pooling1DArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. | /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. | ||||
| @@ -10,11 +12,13 @@ | |||||
| /// <summary> | /// <summary> | ||||
| /// specifying the size of the pooling window. | /// specifying the size of the pooling window. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("pool_size")] | |||||
| public int PoolSize { get; set; } | public int PoolSize { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// specifying the strides of the pooling operation. | /// specifying the strides of the pooling operation. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("strides")] | |||||
| public int Strides { | public int Strides { | ||||
| get { return _strides.HasValue ? _strides.Value : PoolSize; } | get { return _strides.HasValue ? _strides.Value : PoolSize; } | ||||
| set { _strides = value; } | set { _strides = value; } | ||||
| @@ -24,11 +28,13 @@ | |||||
| /// <summary> | /// <summary> | ||||
| /// The padding method, either 'valid' or 'same'. | /// The padding method, either 'valid' or 'same'. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("padding")] | |||||
| public string Padding { get; set; } = "valid"; | public string Padding { get; set; } = "valid"; | ||||
| /// <summary> | /// <summary> | ||||
| /// one of `channels_last` (default) or `channels_first`. | /// one of `channels_last` (default) or `channels_first`. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("data_format")] | |||||
| public string DataFormat { get; set; } | public string DataFormat { get; set; } | ||||
| } | } | ||||
| } | } | ||||
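The nullable `_strides` backing field makes `Strides` fall back to `PoolSize` until it is assigned explicitly — a small sketch (hypothetical usage, not part of the diff):

    var args = new Pooling1DArgs { PoolSize = 3 };
    var s1 = args.Strides;   // 3: no explicit value yet, falls back to PoolSize
    args.Strides = 1;
    var s2 = args.Strides;   // 1: the explicit value now wins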
| @@ -1,6 +1,8 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class Pooling2DArgs : LayerArgs | |||||
| public class Pooling2DArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. | /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. | ||||
| @@ -10,21 +12,25 @@ | |||||
| /// <summary> | /// <summary> | ||||
| /// specifying the size of the pooling window. | /// specifying the size of the pooling window. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("pool_size")] | |||||
| public Shape PoolSize { get; set; } | public Shape PoolSize { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// specifying the strides of the pooling operation. | /// specifying the strides of the pooling operation. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("strides")] | |||||
| public Shape Strides { get; set; } | public Shape Strides { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// The padding method, either 'valid' or 'same'. | /// The padding method, either 'valid' or 'same'. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("padding")] | |||||
| public string Padding { get; set; } = "valid"; | public string Padding { get; set; } = "valid"; | ||||
| /// <summary> | /// <summary> | ||||
| /// one of `channels_last` (default) or `channels_first`. | /// one of `channels_last` (default) or `channels_first`. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("data_format")] | |||||
| public string DataFormat { get; set; } | public string DataFormat { get; set; } | ||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,7 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class PreprocessingLayerArgs : LayerArgs | |||||
| public class PreprocessingLayerArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,12 @@ | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class RescalingArgs : AutoSerializeLayerArgs | |||||
| { | |||||
| [JsonProperty("scale")] | |||||
| public float Scale { get; set; } | |||||
| [JsonProperty("offset")] | |||||
| public float Offset { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -1,5 +1,6 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| // TODO: no corresponding class found in Keras Python, maybe obsolete? | |||||
| public class ResizingArgs : PreprocessingLayerArgs | public class ResizingArgs : PreprocessingLayerArgs | ||||
| { | { | ||||
| public int Height { get; set; } | public int Height { get; set; } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| @@ -6,11 +7,19 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class TextVectorizationArgs : PreprocessingLayerArgs | public class TextVectorizationArgs : PreprocessingLayerArgs | ||||
| { | { | ||||
| [JsonProperty("standardize")] | |||||
| public Func<Tensor, Tensor> Standardize { get; set; } | public Func<Tensor, Tensor> Standardize { get; set; } | ||||
| [JsonProperty("split")] | |||||
| public string Split { get; set; } = "standardize"; | public string Split { get; set; } = "standardize"; | ||||
| [JsonProperty("max_tokens")] | |||||
| public int MaxTokens { get; set; } = -1; | public int MaxTokens { get; set; } = -1; | ||||
| [JsonProperty("output_mode")] | |||||
| public string OutputMode { get; set; } = "int"; | public string OutputMode { get; set; } = "int"; | ||||
| [JsonProperty("output_sequence_length")] | |||||
| public int OutputSequenceLength { get; set; } = -1; | public int OutputSequenceLength { get; set; } = -1; | ||||
| [JsonProperty("vocabulary")] | |||||
| public string[] Vocabulary { get; set; } | public string[] Vocabulary { get; set; } | ||||
| // TODO: Add `ngrams`, `sparse`, `ragged`, `idf_weights`, `encoding` | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,21 +1,26 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class DropoutArgs : LayerArgs | |||||
| public class DropoutArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Float between 0 and 1. Fraction of the input units to drop. | /// Float between 0 and 1. Fraction of the input units to drop. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("rate")] | |||||
| public float Rate { get; set; } | public float Rate { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// 1D integer tensor representing the shape of the | /// 1D integer tensor representing the shape of the | ||||
| /// binary dropout mask that will be multiplied with the input. | /// binary dropout mask that will be multiplied with the input. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("noise_shape")] | |||||
| public Shape NoiseShape { get; set; } | public Shape NoiseShape { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// random seed. | /// random seed. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("seed")] | |||||
| public int? Seed { get; set; } | public int? Seed { get; set; } | ||||
| public bool SupportsMasking { get; set; } | public bool SupportsMasking { get; set; } | ||||
| @@ -1,8 +0,0 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class RescalingArgs : LayerArgs | |||||
| { | |||||
| public float Scale { get; set; } | |||||
| public float Offset { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,18 @@ | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Reshaping | |||||
| { | |||||
| public class Cropping2DArgs : LayerArgs | |||||
| { | |||||
| /// <summary> | |||||
| /// channels_last: (b, h, w, c) | |||||
| /// channels_first: (b, c, h, w) | |||||
| /// </summary> | |||||
| public enum DataFormat { channels_first = 0, channels_last = 1 } | |||||
| /// <summary> | |||||
| /// Accept: int[1][2], int[1][1], int[2][2] | |||||
| /// </summary> | |||||
| public NDArray cropping { get; set; } | |||||
| public DataFormat data_format { get; set; } = DataFormat.channels_last; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,18 @@ | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Reshaping | |||||
| { | |||||
| public class Cropping3DArgs : LayerArgs | |||||
| { | |||||
| /// <summary> | |||||
| /// channels_last: (b, d1, d2, d3, c) | |||||
| /// channels_first: (b, c, d1, d2, d3) | |||||
| /// </summary> | |||||
| public enum DataFormat { channels_first = 0, channels_last = 1 } | |||||
| /// <summary> | |||||
| /// Accept: int[1][3], int[1][1], int[3][2] | |||||
| /// </summary> | |||||
| public NDArray cropping { get; set; } | |||||
| public DataFormat data_format { get; set; } = DataFormat.channels_last; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,12 @@ | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Reshaping | |||||
| { | |||||
| public class Cropping1DArgs : LayerArgs | |||||
| { | |||||
| /// <summary> | |||||
| /// Accept length 1 or 2 | |||||
| /// </summary> | |||||
| public NDArray cropping { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -1,7 +1,10 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class FlattenArgs : LayerArgs | |||||
| public class FlattenArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| [JsonProperty("data_format")] | |||||
| public string DataFormat { get; set; } | public string DataFormat { get; set; } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,5 +1,9 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition { | |||||
| public class PermuteArgs : LayerArgs { | |||||
| public int[] dims { get; set; } | |||||
| } | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition { | |||||
| public class PermuteArgs : AutoSerializeLayerArgs | |||||
| { | |||||
| [JsonProperty("dims")] | |||||
| public int[] dims { get; set; } | |||||
| } | |||||
| } | } | ||||
| @@ -1,7 +1,10 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class ReshapeArgs : LayerArgs | |||||
| public class ReshapeArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| [JsonProperty("target_shape")] | |||||
| public Shape TargetShape { get; set; } | public Shape TargetShape { get; set; } | ||||
| public object[] TargetShapeObjects { get; set; } | public object[] TargetShapeObjects { get; set; } | ||||
| } | } | ||||
| @@ -1,12 +1,17 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class UpSampling2DArgs : LayerArgs | |||||
| public class UpSampling2DArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| [JsonProperty("size")] | |||||
| public Shape Size { get; set; } | public Shape Size { get; set; } | ||||
| [JsonProperty("data_format")] | |||||
| public string DataFormat { get; set; } | public string DataFormat { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// 'nearest', 'bilinear' | /// 'nearest', 'bilinear' | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("interpolation")] | |||||
| public string Interpolation { get; set; } = "nearest"; | public string Interpolation { get; set; } = "nearest"; | ||||
| } | } | ||||
| } | } | ||||
| @@ -2,6 +2,7 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| // TODO: complete the implementation | |||||
| public class ZeroPadding2DArgs : LayerArgs | public class ZeroPadding2DArgs : LayerArgs | ||||
| { | { | ||||
| public NDArray Padding { get; set; } | public NDArray Padding { get; set; } | ||||
| @@ -1,9 +1,8 @@ | |||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Lstm | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
| { | { | ||||
| public class LSTMArgs : RNNArgs | public class LSTMArgs : RNNArgs | ||||
| { | { | ||||
| // TODO: maybe change the `RNNArgs` and implement this class. | |||||
| public bool UnitForgetBias { get; set; } | public bool UnitForgetBias { get; set; } | ||||
| public float Dropout { get; set; } | public float Dropout { get; set; } | ||||
| public float RecurrentDropout { get; set; } | public float RecurrentDropout { get; set; } | ||||
| @@ -0,0 +1,7 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
| { | |||||
| // TODO: complete the implementation | |||||
| public class LSTMCellArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -1,21 +1,30 @@ | |||||
| using System.Collections.Generic; | |||||
| using Newtonsoft.Json; | |||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | namespace Tensorflow.Keras.ArgsDefinition.Rnn | ||||
| { | { | ||||
| public class RNNArgs : LayerArgs | |||||
| public class RNNArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| public interface IRnnArgCell : ILayer | public interface IRnnArgCell : ILayer | ||||
| { | { | ||||
| object state_size { get; } | object state_size { get; } | ||||
| } | } | ||||
| [JsonProperty("cell")] | |||||
| // TODO: the cell should be serialized with `serialize_keras_object`. | |||||
| public IRnnArgCell Cell { get; set; } = null; | public IRnnArgCell Cell { get; set; } = null; | ||||
| [JsonProperty("return_sequences")] | |||||
| public bool ReturnSequences { get; set; } = false; | public bool ReturnSequences { get; set; } = false; | ||||
| [JsonProperty("return_state")] | |||||
| public bool ReturnState { get; set; } = false; | public bool ReturnState { get; set; } = false; | ||||
| [JsonProperty("go_backwards")] | |||||
| public bool GoBackwards { get; set; } = false; | public bool GoBackwards { get; set; } = false; | ||||
| [JsonProperty("stateful")] | |||||
| public bool Stateful { get; set; } = false; | public bool Stateful { get; set; } = false; | ||||
| [JsonProperty("unroll")] | |||||
| public bool Unroll { get; set; } = false; | public bool Unroll { get; set; } = false; | ||||
| [JsonProperty("time_major")] | |||||
| public bool TimeMajor { get; set; } = false; | public bool TimeMajor { get; set; } = false; | ||||
| // TODO: Add `num_constants` and `zero_output_for_mask`. | |||||
| public Dictionary<string, object> Kwargs { get; set; } = null; | public Dictionary<string, object> Kwargs { get; set; } = null; | ||||
| public int Units { get; set; } | public int Units { get; set; } | ||||
| @@ -0,0 +1,50 @@ | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Converters; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Common | |||||
| { | |||||
| public class CustomizedActivationJsonConverter : JsonConverter | |||||
| { | |||||
| public override bool CanConvert(Type objectType) | |||||
| { | |||||
| return objectType == typeof(Activation); | |||||
| } | |||||
| public override bool CanRead => true; | |||||
| public override bool CanWrite => true; | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
| { | |||||
| if (value is null) | |||||
| { | |||||
| var token = JToken.FromObject(""); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| else if (value is not Activation) | |||||
| { | |||||
| throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}."); | |||||
| } | |||||
| else | |||||
| { | |||||
| var token = JToken.FromObject((value as Activation)!.GetType().Name); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| } | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| //var name = serializer.Deserialize(reader, typeof(string)); | |||||
| //if (name is null) | |||||
| //{ | |||||
| // throw new ValueError("Cannot deserialize 'null' to `Activation`."); | |||||
| //} | |||||
| //return ...; // TODO: resolve the stored type name back to an Activation instance. | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,48 @@ | |||||
| using Newtonsoft.Json.Linq; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Common | |||||
| { | |||||
| public class CustomizedAxisJsonConverter : JsonConverter | |||||
| { | |||||
| public override bool CanConvert(Type objectType) | |||||
| { | |||||
| return objectType == typeof(Axis); | |||||
| } | |||||
| public override bool CanRead => true; | |||||
| public override bool CanWrite => true; | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
| { | |||||
| if (value is null) | |||||
| { | |||||
| var token = JToken.FromObject(new int[] { }); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| else if (value is not Axis) | |||||
| { | |||||
| throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}."); | |||||
| } | |||||
| else | |||||
| { | |||||
| var token = JToken.FromObject((value as Axis)!.axis); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| } | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
| { | |||||
| // Deserialize directly as int[]; an object-typed long[] cannot be cast to int[] at runtime. | |||||
| var axis = serializer.Deserialize(reader, typeof(int[])) as int[]; | |||||
| if (axis is null) | |||||
| { | |||||
| throw new ValueError("Cannot deserialize 'null' to `Axis`."); | |||||
| } | |||||
| return new Axis(axis); | |||||
| } | |||||
| } | |||||
| } | |||||
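With this converter bound to `Axis` (the [JsonConverter] attribute is added to `Axis` further down in this diff), an axis value round-trips as a bare JSON array — a sketch (hypothetical usage, not part of the diff):

    using Newtonsoft.Json;

    var axis = new Axis(1, 2);
    string json = JsonConvert.SerializeObject(axis);           // "[1,2]"
    var restored = JsonConvert.DeserializeObject<Axis>(json);  // Axis with axis == {1, 2}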
| @@ -0,0 +1,73 @@ | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Converters; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Common | |||||
| { | |||||
| public class CustomizedNodeConfigJsonConverter : JsonConverter | |||||
| { | |||||
| public override bool CanConvert(Type objectType) | |||||
| { | |||||
| return objectType == typeof(NodeConfig); | |||||
| } | |||||
| public override bool CanRead => true; | |||||
| public override bool CanWrite => true; | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
| { | |||||
| if (value is null) | |||||
| { | |||||
| // JToken.FromObject(null) throws ArgumentNullException, so write the JSON null directly. | |||||
| writer.WriteNull(); | |||||
| } | |||||
| else if (value is not NodeConfig) | |||||
| { | |||||
| throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}."); | |||||
| } | |||||
| else | |||||
| { | |||||
| var config = value as NodeConfig; | |||||
| var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex }); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| } | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
| { | |||||
| var values = serializer.Deserialize(reader, typeof(object[])) as object[]; | |||||
| if (values is null) | |||||
| { | |||||
| throw new ValueError("Cannot deserialize 'null' to `NodeConfig`."); | |||||
| } | |||||
| if(values.Length != 3) | |||||
| { | |||||
| throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||||
| } | |||||
| if (values[0] is not string) | |||||
| { | |||||
| throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); | |||||
| } | |||||
| // Newtonsoft deserializes JSON integers into `long`, so accept both integer types here. | |||||
| if (values[1] is not int && values[1] is not long) | |||||
| { | |||||
| throw new TypeError($"The second value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); | |||||
| } | |||||
| if (values[2] is not int && values[2] is not long) | |||||
| { | |||||
| throw new TypeError($"The third value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); | |||||
| } | |||||
| return new NodeConfig() | |||||
| { | |||||
| Name = values[0] as string, | |||||
| NodeIndex = Convert.ToInt32(values[1]), | |||||
| TensorIndex = Convert.ToInt32(values[2]) | |||||
| }; | |||||
| } | |||||
| } | |||||
| } | |||||
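Together with the [JsonConverter] attribute this diff adds to `NodeConfig`, a node is written in the three-element form Keras expects — a sketch (hypothetical usage, not part of the diff):

    using Newtonsoft.Json;
    using Tensorflow.Keras.Saving;

    var node = new NodeConfig() { Name = "dense", NodeIndex = 0, TensorIndex = 0 };
    string json = JsonConvert.SerializeObject(node);              // ["dense",0,0]
    var back = JsonConvert.DeserializeObject<NodeConfig>(json);   // rebuilt from the triple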
| @@ -0,0 +1,67 @@ | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Converters; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Common | |||||
| { | |||||
| public class CustomizedShapeJsonConverter: JsonConverter | |||||
| { | |||||
| public override bool CanConvert(Type objectType) | |||||
| { | |||||
| return objectType == typeof(Shape); | |||||
| } | |||||
| public override bool CanRead => true; | |||||
| public override bool CanWrite => true; | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
| { | |||||
| if(value is null) | |||||
| { | |||||
| // JToken.FromObject(null) throws ArgumentNullException, so write the JSON null directly. | |||||
| writer.WriteNull(); | |||||
| } | |||||
| else if(value is not Shape) | |||||
| { | |||||
| throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); | |||||
| } | |||||
| else | |||||
| { | |||||
| var shape = (value as Shape)!; | |||||
| long?[] dims = new long?[shape.ndim]; | |||||
| for(int i = 0; i < dims.Length; i++) | |||||
| { | |||||
| if (shape.dims[i] == -1) | |||||
| { | |||||
| dims[i] = null; | |||||
| } | |||||
| else | |||||
| { | |||||
| dims[i] = shape.dims[i]; | |||||
| } | |||||
| } | |||||
| var token = JToken.FromObject(dims); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| } | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
| { | |||||
| var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
| if(dims is null) | |||||
| { | |||||
| throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||||
| } | |||||
| long[] convertedDims = new long[dims.Length]; | |||||
| for(int i = 0; i < dims.Length; i++) | |||||
| { | |||||
| convertedDims[i] = dims[i] ?? (-1); | |||||
| } | |||||
| return new Shape(convertedDims); | |||||
| } | |||||
| } | |||||
| } | |||||
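The null-for-unknown convention above mirrors how Keras stores shapes in JSON: dimension -1 on the C# side becomes null in the file and vice versa. A round-trip sketch (hypothetical usage, not part of the diff):

    using Newtonsoft.Json;

    var shape = new Shape(-1, 28, 28, 1);
    string json = JsonConvert.SerializeObject(shape);         // "[null,28,28,1]"
    var back = JsonConvert.DeserializeObject<Shape>(json);    // (-1, 28, 28, 1)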
| @@ -16,23 +16,27 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Specifies the ndim, dtype and shape of every input to a layer. | /// Specifies the ndim, dtype and shape of every input to a layer. | ||||
| /// </summary> | /// </summary> | ||||
| public class InputSpec | |||||
| public class InputSpec: IKerasConfigable | |||||
| { | { | ||||
| public int? ndim; | public int? ndim; | ||||
| public int? max_ndim; | |||||
| public int? min_ndim; | public int? min_ndim; | ||||
| Dictionary<int, int> axes; | Dictionary<int, int> axes; | ||||
| Shape shape; | Shape shape; | ||||
| TF_DataType dtype; | |||||
| public int[] AllAxisDim; | public int[] AllAxisDim; | ||||
| public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| int? ndim = null, | int? ndim = null, | ||||
| int? min_ndim = null, | int? min_ndim = null, | ||||
| int? max_ndim = null, | |||||
| Dictionary<int, int> axes = null, | Dictionary<int, int> axes = null, | ||||
| Shape shape = null) | Shape shape = null) | ||||
| { | { | ||||
| @@ -41,7 +45,9 @@ namespace Tensorflow.Keras.Engine | |||||
| axes = new Dictionary<int, int>(); | axes = new Dictionary<int, int>(); | ||||
| this.axes = axes; | this.axes = axes; | ||||
| this.min_ndim = min_ndim; | this.min_ndim = min_ndim; | ||||
| this.max_ndim = max_ndim; | |||||
| this.shape = shape; | this.shape = shape; | ||||
| this.dtype = dtype; | |||||
| if (ndim == null && shape != null) | if (ndim == null && shape != null) | ||||
| this.ndim = shape.ndim; | this.ndim = shape.ndim; | ||||
| @@ -49,7 +55,30 @@ namespace Tensorflow.Keras.Engine | |||||
| AllAxisDim = axes.Select(x => x.Value).ToArray(); | AllAxisDim = axes.Select(x => x.Value).ToArray(); | ||||
| } | } | ||||
| public IKerasConfig get_config() | |||||
| { | |||||
| return new Config() | |||||
| { | |||||
| DType = dtype == TF_DataType.DtInvalid ? null : dtype, | |||||
| Shape = shape, | |||||
| Ndim = ndim, | |||||
| MinNdim = min_ndim, | |||||
| MaxNdim = max_ndim, | |||||
| Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value) | |||||
| }; | |||||
| } | |||||
| public override string ToString() | public override string ToString() | ||||
| => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; | => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; | ||||
| public class Config: IKerasConfig | |||||
| { | |||||
| public TF_DataType? DType { get; set; } | |||||
| public Shape Shape { get; set; } | |||||
| public int? Ndim { get; set; } | |||||
| public int? MinNdim { get;set; } | |||||
| public int? MaxNdim { get;set; } | |||||
| public IDictionary<string, int> Axes { get; set; } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
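`InputSpec.get_config` packages the spec into a nested `Config` object whose `Axes` keys are stringified, matching the Python-side dict. A sketch (hypothetical usage, not part of the diff):

    var spec = new InputSpec(ndim: 4, axes: new Dictionary<int, int> { [-1] = 3 });
    var config = (InputSpec.Config)spec.get_config();
    // config.Ndim == 4, config.MinNdim == null, config.DType == null (dtype was DtInvalid),
    // config.Axes == { "-1": 3 }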
| @@ -1,10 +1,12 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Training; | |||||
| namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
| { | { | ||||
| public interface ILayer | |||||
| public interface ILayer: IWithTrackable, IKerasConfigable | |||||
| { | { | ||||
| string Name { get; } | string Name { get; } | ||||
| bool Trainable { get; } | bool Trainable { get; } | ||||
| @@ -19,8 +21,8 @@ namespace Tensorflow.Keras | |||||
| List<IVariableV1> NonTrainableWeights { get; } | List<IVariableV1> NonTrainableWeights { get; } | ||||
| Shape OutputShape { get; } | Shape OutputShape { get; } | ||||
| Shape BatchInputShape { get; } | Shape BatchInputShape { get; } | ||||
| TensorShapeConfig BuildInputShape { get; } | |||||
| TF_DataType DType { get; } | TF_DataType DType { get; } | ||||
| int count_params(); | int count_params(); | ||||
| LayerArgs get_config(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,5 +1,5 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.ArgsDefinition.Reshaping; | |||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| @@ -0,0 +1,15 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Saving | |||||
| { | |||||
| public interface IKerasConfig | |||||
| { | |||||
| } | |||||
| public interface IKerasConfigable | |||||
| { | |||||
| IKerasConfig get_config(); | |||||
| } | |||||
| } | |||||
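`IKerasConfig` is a marker for config payloads; `IKerasConfigable` marks anything that can produce one. A hypothetical implementation sketch of the intended split (all names invented for illustration):

    class WindowConfig : IKerasConfig
    {
        public int Width { get; set; }
    }

    class Window : IKerasConfigable
    {
        int _width = 8;
        // Snapshot the object's state into a serializable config payload.
        public IKerasConfig get_config() => new WindowConfig { Width = _width };
    }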
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| @@ -6,11 +7,15 @@ using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
| { | { | ||||
| public class LayerConfig | |||||
| public class LayerConfig: IKerasConfig | |||||
| { | { | ||||
| [JsonProperty("name")] | |||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| [JsonProperty("class_name")] | |||||
| public string ClassName { get; set; } | public string ClassName { get; set; } | ||||
| [JsonProperty("config")] | |||||
| public LayerArgs Config { get; set; } | public LayerArgs Config { get; set; } | ||||
| [JsonProperty("inbound_nodes")] | |||||
| public List<NodeConfig> InboundNodes { get; set; } | public List<NodeConfig> InboundNodes { get; set; } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,15 +1,20 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
| { | { | ||||
| public class ModelConfig | |||||
| public class ModelConfig : IKerasConfig | |||||
| { | { | ||||
| [JsonProperty("name")] | |||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| [JsonProperty("layers")] | |||||
| public List<LayerConfig> Layers { get; set; } | public List<LayerConfig> Layers { get; set; } | ||||
| [JsonProperty("input_layers")] | |||||
| public List<NodeConfig> InputLayers { get; set; } | public List<NodeConfig> InputLayers { get; set; } | ||||
| [JsonProperty("output_layers")] | |||||
| public List<NodeConfig> OutputLayers { get; set; } | public List<NodeConfig> OutputLayers { get; set; } | ||||
| public override string ToString() | public override string ToString() | ||||
| @@ -1,10 +1,13 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Common; | |||||
| namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
| { | { | ||||
| public class NodeConfig | |||||
| [JsonConverter(typeof(CustomizedNodeConfigJsonConverter))] | |||||
| public class NodeConfig : IKerasConfig | |||||
| { | { | ||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| public int NodeIndex { get; set; } | public int NodeIndex { get; set; } | ||||
| @@ -0,0 +1,35 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel | |||||
| { | |||||
| public interface ISerializedAttributes | |||||
| { | |||||
| IDictionary<string, Trackable> Functions { get; } | |||||
| IDictionary<string, Trackable> CheckpointableObjects { get; } | |||||
| /// <summary> | |||||
| /// Returns functions to attach to the root object during serialization. | |||||
| /// </summary> | |||||
| IDictionary<string, Trackable> FunctionsToSerialize { get; } | |||||
| /// <summary> | |||||
| /// Returns objects to attach to the root object during serialization. | |||||
| /// </summary> | |||||
| IDictionary<string, Trackable> ObjectsToSerialize { get; } | |||||
| /// <summary> | |||||
| /// Saves function dictionary, and validates dictionary values. | |||||
| /// </summary> | |||||
| /// <param name="function_dict"></param> | |||||
| IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict); | |||||
| /// <summary> | |||||
| /// Saves objects to a dictionary, and validates the values. | |||||
| /// </summary> | |||||
| /// <param name="object_dict"></param> | |||||
| IDictionary<string, Trackable> set_and_validate_objects(IDictionary<string, Trackable> object_dict); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,21 @@ | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| namespace Tensorflow.Keras.Saving | |||||
| { | |||||
| public class TensorShapeConfig | |||||
| { | |||||
| [JsonProperty("class_name")] | |||||
| public string ClassName { get; set; } = "TensorShape"; | |||||
| [JsonProperty("items")] | |||||
| public long?[] Items { get; set; } | |||||
| public static implicit operator Shape(TensorShapeConfig shape) | |||||
| => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); | |||||
| public static implicit operator TensorShapeConfig(Shape shape) | |||||
| => shape == null ? null : new TensorShapeConfig() { Items = shape.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() }; | |||||
| } | |||||
| } | |||||
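The two implicit operators let `Shape` and its JSON-facing form be swapped freely, with -1 and null standing in for each other. A sketch (hypothetical usage, not part of the diff):

    TensorShapeConfig config = new Shape(-1, 32);  // config.Items == [null, 32]
    Shape shape = config;                          // back to (-1, 32)
    Shape none = (TensorShapeConfig)null;          // null-safe: yields a null Shape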
| @@ -9,10 +9,52 @@ namespace Tensorflow.ModelSaving | |||||
| /// </summary> | /// </summary> | ||||
| public class SaveOptions | public class SaveOptions | ||||
| { | { | ||||
| bool save_debug_info; | |||||
| public bool save_debug_info = false; | |||||
| public IList<string>? namespace_white_list { get; set; } = null; | |||||
| public IDictionary<string, object>? function_aliases { get; set; } = null; | |||||
| public string? experimental_io_device { get; set; } = null; | |||||
| // TODO: experimental | |||||
| public VariablePolicy experimental_variable_policy { get; set; } = VariablePolicy.None; | |||||
| public bool experimental_custom_gradients { get; set; } = true; | |||||
| public SaveOptions(bool save_debug_info = false) | public SaveOptions(bool save_debug_info = false) | ||||
| { | { | ||||
| this.save_debug_info = save_debug_info; | this.save_debug_info = save_debug_info; | ||||
| } | } | ||||
| } | } | ||||
| public class VariablePolicy | |||||
| { | |||||
| public string Policy { get; } | |||||
| private VariablePolicy(string policy) | |||||
| { | |||||
| Policy = policy; | |||||
| } | |||||
| public static VariablePolicy None = new(null); | |||||
| public static VariablePolicy SAVE_VARIABLE_DEVICES = new("save_variable_devices"); | |||||
| public static VariablePolicy EXPAND_DISTRIBUTED_VARIABLES = new("expand_distributed_variables"); | |||||
| public bool save_variable_devices() | |||||
| { | |||||
| return this != VariablePolicy.None; | |||||
| } | |||||
| /// <summary> | |||||
| /// Tries to convert `obj` to a VariablePolicy instance. | |||||
| /// </summary> | |||||
| /// <param name="obj"></param> | |||||
| /// <returns></returns> | |||||
| public static VariablePolicy from_obj(object obj) | |||||
| { | |||||
| if (obj is null) return VariablePolicy.None; | |||||
| if (obj is VariablePolicy) return (VariablePolicy)obj; | |||||
| var key = obj.ToString().ToLower(); | |||||
| return key switch | |||||
| { | |||||
| null => VariablePolicy.None, | |||||
| "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, | |||||
| "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, | |||||
| _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") | |||||
| }; | |||||
| } | |||||
| } | |||||
| } | } | ||||
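`VariablePolicy.from_obj` accepts null, an existing policy, or a case-insensitive string name. A sketch (hypothetical usage, not part of the diff):

    var a = VariablePolicy.from_obj(null);                      // VariablePolicy.None
    var b = VariablePolicy.from_obj("SAVE_VARIABLE_DEVICES");   // matched after ToLower()
    bool saves = b.save_variable_devices();                     // true for any policy except None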
| @@ -14,20 +14,29 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Newtonsoft.Json; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Common; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public record Axis(params int[] axis) | |||||
| [JsonConverter(typeof(CustomizedAxisJsonConverter))] | |||||
| public class Axis | |||||
| { | { | ||||
| public int[] axis { get; set; } | |||||
| public int size => axis == null ? -1 : axis.Length; | public int size => axis == null ? -1 : axis.Length; | ||||
| public bool IsScalar { get; init; } | public bool IsScalar { get; init; } | ||||
| public int this[int index] => axis[index]; | public int this[int index] => axis[index]; | ||||
| public Axis(params int[] axis) | |||||
| { | |||||
| this.axis = axis; | |||||
| } | |||||
| public static implicit operator int[]?(Axis axis) | public static implicit operator int[]?(Axis axis) | ||||
| => axis?.axis; | => axis?.axis; | ||||
| @@ -14,14 +14,17 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Newtonsoft.Json; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Common; | |||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| [JsonConverter(typeof(CustomizedShapeJsonConverter))] | |||||
| public class Shape | public class Shape | ||||
| { | { | ||||
| public int ndim => _dims == null ? -1 : _dims.Length; | public int ndim => _dims == null ? -1 : _dims.Length; | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class Constant<T> : IInitializer | public class Constant<T> : IInitializer | ||||
| @@ -22,11 +24,19 @@ namespace Tensorflow.Operations.Initializers | |||||
| T value; | T value; | ||||
| bool _verify_shape; | bool _verify_shape; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "Constant"; | |||||
| public IDictionary<string, object> Config => _config; | |||||
| public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | ||||
| { | { | ||||
| this.value = value; | this.value = value; | ||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| _verify_shape = verify_shape; | _verify_shape = verify_shape; | ||||
| _config = new Dictionary<string, object>(); | |||||
| _config["value"] = this.value; | |||||
| } | } | ||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| @@ -14,10 +14,17 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class GlorotUniform : VarianceScaling | public class GlorotUniform : VarianceScaling | ||||
| { | { | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public override string ClassName => "GlorotUniform"; | |||||
| public override IDictionary<string, object> Config => _config; | |||||
| public GlorotUniform(float scale = 1.0f, | public GlorotUniform(float scale = 1.0f, | ||||
| string mode = "FAN_AVG", | string mode = "FAN_AVG", | ||||
| bool uniform = true, | bool uniform = true, | ||||
| @@ -28,7 +35,8 @@ namespace Tensorflow.Operations.Initializers | |||||
| seed: seed, | seed: seed, | ||||
| dtype: dtype) | dtype: dtype) | ||||
| { | { | ||||
| _config = new Dictionary<string, object>(); | |||||
| _config["seed"] = _seed; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -14,10 +14,17 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Newtonsoft.Json; | |||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public interface IInitializer | public interface IInitializer | ||||
| { | { | ||||
| [JsonProperty("class_name")] | |||||
| string ClassName { get; } | |||||
| [JsonProperty("config")] | |||||
| IDictionary<string, object> Config { get; } | |||||
| Tensor Apply(InitializerArgs args); | Tensor Apply(InitializerArgs args); | ||||
| } | } | ||||
| } | } | ||||
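The `ClassName`/`Config` pair gives every initializer the two pieces Keras persists as {"class_name": ..., "config": {...}}. Reading them off the `RandomNormal` change further down (hypothetical usage, not part of the diff):

    IInitializer init = new RandomNormal(mean: 0.0f, stddev: 0.05f, seed: 42);
    // init.ClassName == "RandomNormal"
    // init.Config holds {"mean": 0, "stddev": 0.05, "seed": 42}, ready to sit
    // under the "config" key of a serialized initializer entry.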
| @@ -14,12 +14,19 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class Ones : IInitializer | public class Ones : IInitializer | ||||
| { | { | ||||
| private TF_DataType dtype; | private TF_DataType dtype; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "Ones"; | |||||
| public IDictionary<string, object> Config => new Dictionary<string, object>(); | |||||
| public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) | public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
| { | { | ||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| @@ -1,4 +1,4 @@ | |||||
| /***************************************************************************** | |||||
| /***************************************************************************** | |||||
| Copyright 2023 Haiping Chen. All Rights Reserved. | Copyright 2023 Haiping Chen. All Rights Reserved. | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| @@ -19,6 +19,7 @@ using System.Linq; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Operations.Initializers; | namespace Tensorflow.Operations.Initializers; | ||||
| using System.Collections.Generic; | |||||
| public class Orthogonal : IInitializer | public class Orthogonal : IInitializer | ||||
| { | { | ||||
| @@ -31,6 +32,10 @@ public class Orthogonal : IInitializer | |||||
| _seed = seed; | _seed = seed; | ||||
| } | } | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "Orthogonal"; | |||||
| public IDictionary<string, object> Config => throw new NotImplementedException(); | |||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| { | { | ||||
| return _generate_init_val(args.Shape, args.DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : args.DType); | return _generate_init_val(args.Shape, args.DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : args.DType); | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class RandomNormal : IInitializer | public class RandomNormal : IInitializer | ||||
| @@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers | |||||
| private int? seed; | private int? seed; | ||||
| private TF_DataType dtype; | private TF_DataType dtype; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "RandomNormal"; | |||||
| public IDictionary<string, object> Config => _config; | |||||
| public RandomNormal(float mean = 0.0f, | public RandomNormal(float mean = 0.0f, | ||||
| float stddev = 0.05f, | float stddev = 0.05f, | ||||
| int? seed = null, | int? seed = null, | ||||
| @@ -32,6 +39,11 @@ namespace Tensorflow.Operations.Initializers | |||||
| this.stddev = stddev; | this.stddev = stddev; | ||||
| this.seed = seed; | this.seed = seed; | ||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| _config = new Dictionary<string, object>(); | |||||
| _config["mean"] = this.mean; | |||||
| _config["stddev"] = this.stddev; | |||||
| _config["seed"] = this.seed; | |||||
| } | } | ||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class RandomUniform : IInitializer | public class RandomUniform : IInitializer | ||||
| @@ -23,12 +25,22 @@ namespace Tensorflow.Operations.Initializers | |||||
| private float maxval; | private float maxval; | ||||
| private TF_DataType dtype; | private TF_DataType dtype; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "RandomUniform"; | |||||
| public IDictionary<string, object> Config => _config; | |||||
| public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) | public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) | ||||
| { | { | ||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| this.minval = minval; | this.minval = minval; | ||||
| this.maxval = maxval; | this.maxval = maxval; | ||||
| this.seed = seed; | this.seed = seed; | ||||
| _config = new Dictionary<string, object>(); | |||||
| _config["minval"] = this.minval; | |||||
| _config["maxval"] = this.maxval; | |||||
| _config["seed"] = this.seed; | |||||
| } | } | ||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class TruncatedNormal : IInitializer | public class TruncatedNormal : IInitializer | ||||
| @@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers | |||||
| private int? seed; | private int? seed; | ||||
| private TF_DataType dtype; | private TF_DataType dtype; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "TruncatedNormal"; | |||||
| public IDictionary<string, object> Config => _config; | |||||
| public TruncatedNormal(float mean = 0.0f, | public TruncatedNormal(float mean = 0.0f, | ||||
| float stddev = 1.0f, | float stddev = 1.0f, | ||||
| int? seed = null, | int? seed = null, | ||||
| @@ -32,6 +39,10 @@ namespace Tensorflow.Operations.Initializers | |||||
| this.stddev = stddev; | this.stddev = stddev; | ||||
| this.seed = seed; | this.seed = seed; | ||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| _config = new Dictionary<string, object>(); | |||||
| _config["mean"] = this.mean; | |||||
| _config["stddev"] = this.stddev; | |||||
| _config["seed"] = this.seed; | |||||
| } | } | ||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| @@ -15,7 +15,9 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Linq.Expressions; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| @@ -30,6 +32,11 @@ namespace Tensorflow.Operations.Initializers | |||||
| protected int? _seed; | protected int? _seed; | ||||
| protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
| protected bool _uniform; | protected bool _uniform; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public virtual string ClassName => "VarianceScaling"; | |||||
| public virtual IDictionary<string, object> Config => _config; | |||||
| public VarianceScaling(float factor = 2.0f, | public VarianceScaling(float factor = 2.0f, | ||||
| string mode = "FAN_IN", | string mode = "FAN_IN", | ||||
| @@ -50,6 +57,12 @@ namespace Tensorflow.Operations.Initializers | |||||
| _seed = seed; | _seed = seed; | ||||
| _dtype = dtype; | _dtype = dtype; | ||||
| _uniform = uniform; | _uniform = uniform; | ||||
| _config = new(); | |||||
| _config["scale"] = _scale; | |||||
| _config["mode"] = _mode; | |||||
| _config["distribution"] = _distribution; | |||||
| _config["seed"] = _seed; | |||||
| } | } | ||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class Zeros : IInitializer | public class Zeros : IInitializer | ||||
| @@ -21,6 +23,9 @@ namespace Tensorflow.Operations.Initializers | |||||
| Shape shape; | Shape shape; | ||||
| TF_DataType dtype; | TF_DataType dtype; | ||||
| public string ClassName => "Zeros"; | |||||
| public IDictionary<string, object> Config => new Dictionary<string, object>(); | |||||
| public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) | public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
| { | { | ||||
| this.shape = shape; | this.shape = shape; | ||||
| @@ -20,7 +20,9 @@ using Tensorflow.Keras; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -75,6 +77,8 @@ namespace Tensorflow | |||||
| public Shape BatchInputShape => throw new NotImplementedException(); | public Shape BatchInputShape => throw new NotImplementedException(); | ||||
| public TensorShapeConfig BuildInputShape => throw new NotImplementedException(); | |||||
| public TF_DataType DType => throw new NotImplementedException(); | public TF_DataType DType => throw new NotImplementedException(); | ||||
| protected bool built = false; | protected bool built = false; | ||||
| public bool Built => built; | public bool Built => built; | ||||
| @@ -143,7 +147,7 @@ namespace Tensorflow | |||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| public LayerArgs get_config() | |||||
| public IKerasConfig get_config() | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| @@ -152,5 +156,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| public Trackable GetTrackable() { throw new NotImplementedException(); } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,9 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Xml.Linq; | |||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| @@ -17182,17 +17185,47 @@ namespace Tensorflow.Operations | |||||
| /// path in the input checkpoint_prefixes. This is useful when those paths are non | /// path in the input checkpoint_prefixes. This is useful when those paths are non | ||||
| /// user-facing temporary locations. | /// user-facing temporary locations. | ||||
| /// </remarks> | /// </remarks> | ||||
| public static Operation merge_v2checkpoints(Tensor checkpoint_prefixes, Tensor destination_prefix, bool? delete_old_dirs = null, string name = "MergeV2Checkpoints") | |||||
| { | |||||
| public static Operation merge_v2_checkpoints(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs = true, bool allow_missing_files = false, string name = "MergeV2Checkpoints") | |||||
| { | |||||
| var ctx = tf.Context; | |||||
| if (ctx.executing_eagerly()) | |||||
| { | |||||
| tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, | |||||
| checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files)); | |||||
| // MergeV2Checkpoints produces no outputs, so there is no Operation to return in eager mode. | |||||
| return null; | |||||
| //try | |||||
| //{ | |||||
| // var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, | |||||
| // new object[] { checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files })); | |||||
| // result = null; | |||||
| // return null; | |||||
| //} | |||||
| //catch (System.Exception) | |||||
| //{ | |||||
| // return merge_v2_checkpoints_eager_fallback(checkpoint_prefixes, destination_prefix, delete_old_dirs: delete_old_dirs, | |||||
| // allow_missing_files: allow_missing_files, name: name, ctx: ctx); | |||||
| //} | |||||
| } | |||||
| var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
| dict["checkpoint_prefixes"] = checkpoint_prefixes; | dict["checkpoint_prefixes"] = checkpoint_prefixes; | ||||
| dict["destination_prefix"] = destination_prefix; | dict["destination_prefix"] = destination_prefix; | ||||
| if (delete_old_dirs.HasValue) | |||||
| dict["delete_old_dirs"] = delete_old_dirs.Value; | |||||
| dict["delete_old_dirs"] = delete_old_dirs; | |||||
| var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict); | var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict); | ||||
| return op; | return op; | ||||
| } | } | ||||
| //public static Operation merge_v2_checkpoints_eager_fallback(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs, bool allow_missing_files, string name, Context ctx) | |||||
| //{ | |||||
| // checkpoint_prefixes = ops.convert_to_tensor(checkpoint_prefixes, TF_DataType.TF_STRING); | |||||
| // destination_prefix = ops.convert_to_tensor(destination_prefix, TF_DataType.TF_STRING); | |||||
| // var inputs_flat = new Tensor[] { checkpoint_prefixes, destination_prefix }; | |||||
| // var attrs = new object[] { "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }; | |||||
| // var result = execute.quick_execute("MergeV2Checkpoints", 0, inputs_flat, attrs, ctx, name); | |||||
| // result = null; | |||||
| // return null; | |||||
| //} | |||||
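The op wrappers touched in this change all share one dispatch shape: try the TFE fast path when the context is executing eagerly, otherwise record a graph node through `_apply_op_helper`. A minimal sketch of that skeleton, using the same calls as the diff above but with a hypothetical op `MyOp` and attr `some_attr` standing in for the real ones:

```csharp
// Sketch only; assumes the same usings as the generated-ops file above.
public static Operation my_op(Tensor input, bool some_attr = true, string name = "MyOp")
{
    var ctx = tf.Context;
    if (ctx.executing_eagerly())
    {
        // Fast path: run the kernel immediately through the eager runtime.
        tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MyOp", name,
            input, "some_attr", some_attr));
        return null; // ops with no outputs have nothing useful to return eagerly
    }
    // Graph mode: build a node via the op-def helper instead.
    var dict = new Dictionary<string, object>();
    dict["input"] = input;
    dict["some_attr"] = some_attr;
    return tf.OpDefLib._apply_op_helper("MyOp", name: name, keywords: dict);
}
```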
| /// <summary> | /// <summary> | ||||
| /// Transforms a spectrogram into a form that's useful for speech recognition. | /// Transforms a spectrogram into a form that's useful for speech recognition. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -24259,6 +24292,12 @@ namespace Tensorflow.Operations | |||||
| /// </remarks> | /// </remarks> | ||||
| public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch") | public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch") | ||||
| { | { | ||||
| var ctx = tf.Context; | |||||
| if (ctx.executing_eagerly()) | |||||
| { | |||||
| var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("RegexFullMatch", name, input, pattern)); | |||||
| return result[0]; | |||||
| } | |||||
| var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
| dict["input"] = input; | dict["input"] = input; | ||||
| dict["pattern"] = pattern; | dict["pattern"] = pattern; | ||||
| @@ -29744,6 +29783,12 @@ namespace Tensorflow.Operations | |||||
| /// </remarks> | /// </remarks> | ||||
| public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename") | public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename") | ||||
| { | { | ||||
| var ctx = tf.Context; | |||||
| if (ctx.executing_eagerly()) | |||||
| { | |||||
| var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("ShardedFilename", name, basename, shard, num_shards)); | |||||
| return result[0]; | |||||
| } | |||||
| var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
| dict["basename"] = basename; | dict["basename"] = basename; | ||||
| dict["shard"] = shard; | dict["shard"] = shard; | ||||
| @@ -34668,6 +34713,12 @@ namespace Tensorflow.Operations | |||||
| /// </remarks> | /// </remarks> | ||||
| public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin") | public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin") | ||||
| { | { | ||||
| var ctx = tf.Context; | |||||
| if (ctx.executing_eagerly()) | |||||
| { | |||||
| var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("StringJoin", name, inputs, "separator", separator)); | |||||
| return result[0]; | |||||
| } | |||||
| var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
| dict["inputs"] = inputs; | dict["inputs"] = inputs; | ||||
| if (separator != null) | if (separator != null) | ||||
| @@ -14,7 +14,9 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Linq; | |||||
| using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -23,11 +25,41 @@ namespace Tensorflow | |||||
| { | { | ||||
| public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) | public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) | ||||
| { | { | ||||
| var ctx = tf.Context; | |||||
| if (ctx.executing_eagerly()) | |||||
| { | |||||
| try | |||||
| { | |||||
| var result = tf.Runner.TFE_FastPathExecute( | |||||
| new FastPathOpExecInfo("SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors })); | |||||
| result = null; | |||||
| return null; | |||||
| } | |||||
| catch (System.Exception) | |||||
| { | |||||
| return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name, ctx); | |||||
| } | |||||
| } | |||||
| var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | ||||
| return _op; | return _op; | ||||
| } | } | ||||
| public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx) | |||||
| { | |||||
| DataType[] attr_dtypes; | |||||
| (attr_dtypes, tensors) = execute.convert_to_mixed_eager_tensors(tensors, ctx); | |||||
| prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); | |||||
| var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); | |||||
| var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); | |||||
| var inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }.Concat(tensors).ToArray(); | |||||
| var attrs = new object[] { "dtypes", attr_dtypes }; | |||||
| var result = execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name); | |||||
| result = null; | |||||
| return null; | |||||
| } | |||||
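A hedged sketch of driving `save_v2` directly, assuming eager mode; the checkpoint path and tensor names are hypothetical, and the method is shown unqualified although it sits on the ops wrapper class in this hunk:

```csharp
// Writes two tensors to a V2 checkpoint. An empty shape_and_slices entry
// means "save the full tensor" rather than a slice of it.
var prefix = tf.constant("/tmp/model.ckpt");
var tensor_names = new[] { "layer1/kernel", "layer1/bias" };
var shape_and_slices = new[] { "", "" };
var tensors = new Tensor[] { tf.ones(new Shape(3, 3)), tf.zeros(new Shape(3)) };
var save_op = save_v2(prefix, tensor_names, shape_and_slices, tensors);
```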
| public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | ||||
| { | { | ||||
| var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | ||||
| @@ -17,6 +17,9 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using Tensorflow.ModelSaving; | |||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Variables; | |||||
| using static Tensorflow.CppShapeInferenceResult.Types; | using static Tensorflow.CppShapeInferenceResult.Types; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -38,6 +41,11 @@ namespace Tensorflow | |||||
| { | { | ||||
| return var is ResourceVariable; | return var is ResourceVariable; | ||||
| } | } | ||||
| public static bool is_resource_variable(Trackable var) | |||||
| { | |||||
| return var is BaseResourceVariable; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Creates a variable handle with information to do shape inference. | /// Creates a variable handle with information to do shape inference. | ||||
| @@ -171,5 +179,57 @@ namespace Tensorflow | |||||
| return HandleData.Parser.ParseFrom(handle.BufferToArray()); | return HandleData.Parser.ParseFrom(handle.BufferToArray()); | ||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Copies an existing variable to a new graph, with no initializer. | |||||
| /// </summary> | |||||
| /// <param name="variable"></param> | |||||
| public static UninitializedVariable copy_to_graph_uninitialized(ResourceVariable variable) | |||||
| { | |||||
| var new_variable = new UninitializedVariable( | |||||
| trainable: variable.Trainable, | |||||
| shape: variable.shape, | |||||
| dtype: variable.dtype, | |||||
| name: variable.SharedName, | |||||
| aggregation: variable.Aggregation, | |||||
| extra_handle_data: null); | |||||
| new_variable._maybe_initialize_trackable(); | |||||
| return new_variable; | |||||
| } | |||||
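A hedged sketch of how the copy is used during export: the variable's metadata is mirrored into the export graph while its initializer stays behind. The variable here is hypothetical, and `copy_to_graph_uninitialized` is shown unqualified although it lives on the static class in this hunk:

```csharp
// Mirror a variable into the export graph with no initial-value op.
var v = tf.Variable(tf.ones(new Shape(2, 2)), name: "kernel", trainable: true);
var uninitialized = copy_to_graph_uninitialized(v);
// `uninitialized` shares trainable/shape/dtype/name with `v` and is already
// registered as trackable via _maybe_initialize_trackable().
```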
| /// <summary> | |||||
| /// Writes additional information of the variable into the SavedObject proto. | |||||
| /// </summary> | |||||
| /// <param name="resource_variable"></param> | |||||
| /// <param name="proto"></param> | |||||
| /// <param name="options"></param> | |||||
| /// <param name="enforcing_naming"></param> | |||||
| public static void write_object_proto_for_resource_variable(BaseResourceVariable resource_variable, SavedObject proto, SaveOptions options, bool enforcing_naming = true) | |||||
| { | |||||
| // lack of API: `proto.Variable.SetInParent()`. | |||||
| if(enforcing_naming && !resource_variable.Name.EndsWith(":0")) | |||||
| { | |||||
| throw new ValueError($"Cowardly refusing to save variable {resource_variable.Name} because of " + | |||||
| $"unexpected suffix in the name (expected ':0') which won't be restored."); | |||||
| } | |||||
| if(proto.Variable is null) | |||||
| { | |||||
| proto.Variable = new SavedVariable(); | |||||
| } | |||||
| proto.Variable.Name = meta_graph.op_name(resource_variable.Name); | |||||
| proto.Variable.Trainable = resource_variable.Trainable; | |||||
| proto.Variable.Dtype = resource_variable.dtype.as_datatype_enum(); | |||||
| // TODO: lack of API `proto.Variable.Synchronization = resource_variable.synchronization.value`. | |||||
| proto.Variable.Aggregation = resource_variable.Aggregation; | |||||
| proto.Variable.Shape = resource_variable.shape.as_proto(); | |||||
| if (options.experimental_variable_policy.save_variable_devices()) | |||||
| { | |||||
| if (!string.IsNullOrEmpty(resource_variable.Device)) | |||||
| { | |||||
| proto.Variable.Device = resource_variable.Device; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
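A hedged sketch of the proto side of this helper; `variable` is a hypothetical ResourceVariable named "dense/kernel:0", and this assumes `SaveOptions` is constructible with defaults:

```csharp
// Fill a SavedObject's Variable field during SavedModel export.
var proto = new SavedObject();
write_object_proto_for_resource_variable(variable, proto, new SaveOptions());
// proto.Variable.Name is now "dense/kernel" (the ":0" suffix is stripped by
// meta_graph.op_name), with dtype, shape, and Trainable copied alongside.
```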
| @@ -156,7 +156,7 @@ namespace Tensorflow { | |||||
| /// Nodes[0] is considered the root node. | /// Nodes[0] is considered the root node. | ||||
| /// </summary> | /// </summary> | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
| public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes { | |||||
| public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes { | |||||
| get { return nodes_; } | get { return nodes_; } | ||||
| } | } | ||||
| @@ -286,6 +286,7 @@ namespace Tensorflow { | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
| public SavedObject(SavedObject other) : this() { | public SavedObject(SavedObject other) : this() { | ||||
| children_ = other.children_.Clone(); | children_ = other.children_.Clone(); | ||||
| dependencies_ = other.dependencies_.Clone(); | |||||
| slotVariables_ = other.slotVariables_.Clone(); | slotVariables_ = other.slotVariables_.Clone(); | ||||
| saveableObjects_ = other.saveableObjects_.Clone(); | saveableObjects_ = other.saveableObjects_.Clone(); | ||||
| switch (other.KindCase) { | switch (other.KindCase) { | ||||
| @@ -328,6 +329,7 @@ namespace Tensorflow { | |||||
| private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | ||||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | ||||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | ||||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Objects which this object depends on: named edges in the dependency | /// Objects which this object depends on: named edges in the dependency | ||||
| /// graph. | /// graph. | ||||
| @@ -338,6 +340,11 @@ namespace Tensorflow { | |||||
| public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children { | public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children { | ||||
| get { return children_; } | get { return children_; } | ||||
| } | } | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Dependencies { | |||||
| get { return dependencies_; } | |||||
| } | |||||
| /// <summary>Field number for the "slot_variables" field.</summary> | /// <summary>Field number for the "slot_variables" field.</summary> | ||||
| public const int SlotVariablesFieldNumber = 3; | public const int SlotVariablesFieldNumber = 3; | ||||
| @@ -617,6 +624,7 @@ namespace Tensorflow { | |||||
| return; | return; | ||||
| } | } | ||||
| children_.Add(other.children_); | children_.Add(other.children_); | ||||
| dependencies_.Add(other.dependencies_); | |||||
| slotVariables_.Add(other.slotVariables_); | slotVariables_.Add(other.slotVariables_); | ||||
| saveableObjects_.Add(other.saveableObjects_); | saveableObjects_.Add(other.saveableObjects_); | ||||
| switch (other.KindCase) { | switch (other.KindCase) { | ||||
| @@ -198,6 +198,22 @@ namespace Tensorflow { | |||||
| public TrackableObject() { | public TrackableObject() { | ||||
| OnConstruction(); | OnConstruction(); | ||||
| } | } | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot) { | |||||
| OnConstruction(); | |||||
| slotVariables_ = slot; | |||||
| } | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot, | |||||
| pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children | |||||
| ) | |||||
| { | |||||
| OnConstruction(); | |||||
| slotVariables_ = slot; | |||||
| children_ = children; | |||||
| } | |||||
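The new constructors let the object-graph serializer hand collected references straight to the proto node instead of mutating it field by field. A hedged sketch, where `slot_refs` and `child_refs` are hypothetical `RepeatedField` instances built elsewhere:

```csharp
// Build a TrackableObject proto node from previously collected references.
var node = new TrackableObjectGraph.Types.TrackableObject(slot_refs, child_refs);
```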
| partial void OnConstruction(); | partial void OnConstruction(); | ||||
| @@ -108,6 +108,7 @@ https://tensorflownet.readthedocs.io</Description> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | ||||
| <PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> | |||||
| <PackageReference Include="Protobuf.Text" Version="0.6.0" /> | <PackageReference Include="Protobuf.Text" Version="0.6.0" /> | ||||
| <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -202,6 +202,24 @@ namespace Tensorflow | |||||
| _ => type.ToString() | _ => type.ToString() | ||||
| }; | }; | ||||
| public static string as_python_name(this TF_DataType type) | |||||
| => type switch | |||||
| { | |||||
| TF_DataType.TF_STRING => "str", | |||||
| TF_DataType.TF_UINT8 => "uint8", | |||||
| TF_DataType.TF_INT8 => "int8", | |||||
| TF_DataType.TF_UINT32 => "uint32", | |||||
| TF_DataType.TF_INT32 => "int32", | |||||
| TF_DataType.TF_UINT64 => "uint64", | |||||
| TF_DataType.TF_INT64 => "int64", | |||||
| TF_DataType.TF_FLOAT => "float32", | |||||
| TF_DataType.TF_DOUBLE => "float64", | |||||
| TF_DataType.TF_BOOL => "bool", | |||||
| TF_DataType.TF_RESOURCE => "resource", | |||||
| TF_DataType.TF_VARIANT => "variant", | |||||
| _ => type.ToString() | |||||
| }; | |||||
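The switch above is the contract the SavedModel writer relies on: dtype names in exported metadata must match what TensorFlow's Python-side loader expects. A quick usage check:

```csharp
// Python-style dtype names, straight from the mapping above.
var f = TF_DataType.TF_FLOAT.as_python_name();  // "float32"
var s = TF_DataType.TF_STRING.as_python_name(); // "str"
```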
| public static int get_datatype_size(this TF_DataType type) | public static int get_datatype_size(this TF_DataType type) | ||||
| => type.as_base_dtype() switch | => type.as_base_dtype() switch | ||||
| { | { | ||||
| @@ -1,6 +1,71 @@ | |||||
| namespace Tensorflow.Train | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.Keras.Saving.SavedModel; | |||||
| using Tensorflow.Operations.Activation; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Train | |||||
| { | { | ||||
| public abstract class AutoTrackable : Trackable | |||||
| public class AutoTrackable : Trackable | |||||
| { | { | ||||
| public void _delete_tracking(string name) | |||||
| { | |||||
| _maybe_initialize_trackable(); | |||||
| if (_unconditional_dependency_names.ContainsKey(name)) | |||||
| { | |||||
| _unconditional_dependency_names.Remove(name); | |||||
| for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--) | |||||
| { | |||||
| if (_unconditional_checkpoint_dependencies[i].Name == name) | |||||
| { | |||||
| _unconditional_checkpoint_dependencies.RemoveAt(i); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
| { | |||||
| if(save_type != SaveType.SAVEDMODEL) | |||||
| { | |||||
| return base._trackable_children(save_type, cache); | |||||
| } | |||||
| Dictionary<string, Trackable> functions = new(); | |||||
| // TODO: process logs. | |||||
| var properties = this.GetType().GetProperties(); | |||||
| foreach ( var property in properties ) | |||||
| { | |||||
| if(property.PropertyType == typeof(Function) || property.PropertyType == typeof(ConcreteFunction)) | |||||
| { | |||||
| string name = property.Name; | |||||
| object value = property.GetValue(this, null); | |||||
| functions[name] = (Trackable)value; | |||||
| } | |||||
| } | |||||
| // TODO: process the type `core_types.GenericFunction`. | |||||
| Dictionary<string, Trackable> children = new(); | |||||
| foreach(var pair in CheckpointDependencies) | |||||
| { | |||||
| var name = pair.Name; | |||||
| var child = pair.Refer; | |||||
| if(child is ConcreteFunction) // or Generic function | |||||
| { | |||||
| continue; | |||||
| } | |||||
| if(functions.ContainsKey(name) && functions[name] != child) | |||||
| { | |||||
| throw new ValueError($"Can't save object because it has multiple children with the same " + | |||||
| $"name. Object: {this}, attribute name: {name}, child 1: " + | |||||
| $"{child}, child 2: {functions[name]}"); | |||||
| } | |||||
| children[name] = child; | |||||
| } | |||||
| return children.Concat(functions).ToDictionary(x => x.Key, x => x.Value); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
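A hedged sketch of what the reflection scan above picks up: any property typed as `Function` or `ConcreteFunction` becomes a SavedModel child. `BuildForward()` is a hypothetical factory standing in for however the ConcreteFunction is actually constructed:

```csharp
// A trackable whose function-typed property is discovered by
// _trackable_children(SaveType.SAVEDMODEL) via the property scan.
class SimpleModule : AutoTrackable
{
    public ConcreteFunction Forward { get; set; }
}

var module = new SimpleModule { Forward = BuildForward() };
var children = module._trackable_children(SaveType.SAVEDMODEL);
// `children` now carries the "Forward" entry alongside any checkpoint
// dependencies that are not themselves functions.
```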
| @@ -0,0 +1,12 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Training | |||||
| { | |||||
| public interface IWithTrackable | |||||
| { | |||||
| Trackable GetTrackable(); | |||||
| } | |||||
| } | |||||
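A hedged sketch of implementing the new interface so saving utilities can reach a wrapped object's trackable uniformly; note that `AutoTrackable` is instantiable now that its `abstract` modifier was dropped above:

```csharp
// Expose a backing Trackable through IWithTrackable.
class TrackableHolder : IWithTrackable
{
    private readonly Trackable _backing = new AutoTrackable();
    public Trackable GetTrackable() => _backing;
}
```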
| @@ -0,0 +1,9 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Training | |||||
| { | |||||
| } | |||||
| @@ -351,7 +351,7 @@ namespace Tensorflow | |||||
| /// <param name="var"></param> | /// <param name="var"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| protected IVariableV1 get_slot(IVariableV1 var, string name) | |||||
| internal IVariableV1 get_slot(IVariableV1 var, string name) | |||||
| { | { | ||||
| var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; | var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; | ||||
| if (named_slots == null) | if (named_slots == null) | ||||
| @@ -360,6 +360,11 @@ namespace Tensorflow | |||||
| return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; | return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; | ||||
| } | } | ||||
| internal IEnumerable<string> get_slot_names() | |||||
| { | |||||
| return _slots.Keys; | |||||
| } | |||||
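A hedged sketch of the new accessor pair; both are `internal`, so only assembly-internal callers (such as the checkpoint writer) can use them. `optimizer` and `variable` are hypothetical:

```csharp
// Enumerate every slot variable an optimizer tracks for a given variable.
foreach (var slot_name in optimizer.get_slot_names())
{
    var slot = optimizer.get_slot(variable, slot_name);
    if (slot != null)
    {
        // e.g. Adam's "m"/"v" accumulators; include them in the checkpoint.
    }
}
```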
| private string _var_key(IVariableV1 var) | private string _var_key(IVariableV1 var) | ||||
| { | { | ||||
| return $"{var.Op.graph.graph_key}.{var.Op.name}"; | return $"{var.Op.graph.graph_key}.{var.Op.name}"; | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class ResourceVariableSaveable : MySaveableObject | public class ResourceVariableSaveable : MySaveableObject | ||||
| @@ -35,6 +37,32 @@ namespace Tensorflow | |||||
| this.name = name; | this.name = name; | ||||
| } | } | ||||
| public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, string name) | |||||
| { | |||||
| _var_device = var.Device; | |||||
| _var_shape = var.shape; | |||||
| Tensor _read_variable_closure(BaseResourceVariable v) | |||||
| { | |||||
| tf.device(v.Device); | |||||
| if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| var x = v.read_value_no_copy(); | |||||
| tf.device("/device:CPU:0"); | |||||
| return array_ops.identity(x); | |||||
| } | |||||
| this.handle_op = var.Handle; | |||||
| var tensor = _read_variable_closure(var); | |||||
| var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype); | |||||
| _op = var; | |||||
| specs = new SaveSpec[] { spec }; | |||||
| this.name = name; | |||||
| } | |||||
| public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | ||||
| { | { | ||||
| var restored_tensor = restored_tensors[0]; | var restored_tensor = restored_tensors[0]; | ||||
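A hedged sketch of the new constructor above: it snapshots the variable's current value (read on its own device, then pinned to CPU) into a single full-tensor `SaveSpec`. `variable` and the checkpoint key are hypothetical:

```csharp
// Wrap a resource variable for SaveV2; an empty slice spec saves everything.
var saveable = new ResourceVariableSaveable(variable, "", "dense/kernel");
// saveable.specs[0] carries the tensor SaveV2 will write under
// "dense/kernel"; saveable.name matches the checkpoint key.
```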
| @@ -28,7 +28,7 @@ namespace Tensorflow | |||||
| public string slice_spec => _slice_spec; | public string slice_spec => _slice_spec; | ||||
| private string _name; | private string _name; | ||||
| public string name => _name; | |||||
| public string name { get => _name; set => _name = value; } | |||||
| private TF_DataType _dtype; | private TF_DataType _dtype; | ||||
| public TF_DataType dtype => _dtype; | public TF_DataType dtype => _dtype; | ||||