@@ -25,5 +25,32 @@ namespace Tensorflow
        public IntPtr data;
        public ulong length;
        public IntPtr data_deallocator;
        public unsafe Span<T> AsSpan<T>() where T: unmanaged
        {
            if(length > int.MaxValue)
            {
                throw new ValueError($"The length {length} is too large to use in the span.");
            }
            return new Span<T>(data.ToPointer(), (int)length);
        }
        public unsafe byte[] ToByteArray()
        {
            byte[] res = new byte[length];
            if(length > int.MaxValue)
            {
                byte* root = (byte*)data;
                for(ulong i = 0; i < length; i++)
                {
                    res[i] = *(root++);
                }
            }
            else
            {
                new Span<byte>(data.ToPointer(), (int)length).CopyTo(res.AsSpan());
            }
            return res;
        }
    }
}
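For readers unfamiliar with the pattern the new `AsSpan<T>`/`ToByteArray` helpers rely on, the sketch below shows the same idea in isolation: wrap a native pointer plus length in a `Span<T>` and copy it into a managed array. It is a minimal standalone example (the buffer is allocated with `Marshal.AllocHGlobal` rather than owned by TensorFlow, and the demo type name is invented); `Span<T>` can address at most `int.MaxValue` elements, which is why `AsSpan<T>()` above guards the length.

```csharp
using System;
using System.Runtime.InteropServices;

// Compile with unsafe blocks enabled (<AllowUnsafeBlocks>true</AllowUnsafeBlocks>).
class NativeBufferDemo
{
    static unsafe void Main()
    {
        const int length = 8;
        // Allocate and fill a small native buffer, standing in for a TF-owned buffer.
        IntPtr data = Marshal.AllocHGlobal(length);
        try
        {
            byte* p = (byte*)data;
            for (int i = 0; i < length; i++)
                p[i] = (byte)(i * 10);

            // Wrap the native memory without copying, as AsSpan<T>() does.
            Span<byte> view = new Span<byte>(data.ToPointer(), length);

            // Copy into a managed array, the common case in ToByteArray().
            byte[] managed = new byte[length];
            view.CopyTo(managed);

            Console.WriteLine(string.Join(", ", managed)); // 0, 10, 20, ...
        }
        finally
        {
            Marshal.FreeHGlobal(data);
        }
    }
}
```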
@@ -18,6 +18,10 @@ namespace Tensorflow.Eager
            var types = v.Select(t => t.dtype.as_datatype_enum());
            return (types.ToArray(), v.ToArray());
        }
        public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
        {
            return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name);
        }
        public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
        {
            string device_name = ctx.DeviceName;
@@ -149,6 +149,7 @@ namespace Tensorflow
            foreach (var new_op in graph._add_new_tf_operations())
            {
                var original_device = new_op.Device;
                new_op._set_device(original_device);
            }
        }
@@ -1,9 +1,11 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Graphs;
using Tensorflow.Operations;
using Tensorflow.Util;
@@ -16,6 +18,8 @@ namespace Tensorflow.Functions
        public int _num_outputs;
        FuncGraph _func_graph;
        FunctionDef _definition;
        OpDef _signature;
        string _name;
        Tensor[] _func_graph_outputs;
        public string Name => _func_graph.FuncName;
        public DataType[] OutputTypes { get; protected set; }
@@ -31,6 +35,18 @@ namespace Tensorflow.Functions
                return _definition;
            }
        }
        public OpDef Signature
        {
            get
            {
                if( _signature is null)
                {
                    _signature = Definition.Signature;
                }
                return _signature;
            }
        }
        public EagerDefinedFunction(string name, FuncGraph graph,
            Tensors inputs, Tensors outputs,
            Dictionary<string, string> attrs)
@@ -75,12 +91,12 @@ namespace Tensorflow.Functions
            Tensor[] outputs;
            if (executing_eagerly)
            {
                outputs = tf.Runner.TFE_Execute(tf.Context,
                    tf.Context.DeviceName,
                    _func_graph.FuncName,
                    args,
                    attrs,
                    _num_outputs);
                outputs = execute.executes(
                    Signature.Name,
                    _num_outputs,
                    args,
                    attrs,
                    tf.Context);
            }
            else
            {
@@ -135,9 +151,13 @@ namespace Tensorflow.Functions
        private FunctionDef _get_definition()
        {
            var buffer = c_api_util.tf_buffer();
            // TODO(Rinne): pywrap_tf_session.TF_FunctionToFunctionDef
            Status status = new();
            c_api.TF_FunctionToFunctionDef(_func_graph._func_graph_handle, buffer, status);
            status.Check(true);
            var proto_data = c_api.TF_GetBuffer(buffer);
            throw new NotImplementedException();
            FunctionDef function_def = new();
            function_def.MergeFrom(proto_data.AsSpan<byte>());
            return function_def;
        }
    }
}
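`_get_definition` now follows the usual Google.Protobuf round trip: the C API fills a `TF_Buffer`, the buffer is exposed as a `Span<byte>`, and `MergeFrom` parses it into a message. Below is a minimal sketch of that round trip, using a well-known message type instead of `FunctionDef` so it compiles on its own, and assuming a Google.Protobuf version that provides the span-based `MergeFrom` overload (the same overload the `function_def.MergeFrom(proto_data.AsSpan<byte>())` call depends on).

```csharp
using System;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;

class ProtoRoundTripDemo
{
    static void Main()
    {
        // Serialize a message to raw bytes, standing in for the bytes
        // behind the TF_Buffer returned by c_api.TF_GetBuffer.
        var original = new StringValue { Value = "my_function" };
        byte[] wire = original.ToByteArray();

        // Parse the bytes back from a span, mirroring
        // function_def.MergeFrom(proto_data.AsSpan<byte>()) in the diff above.
        var parsed = new StringValue();
        parsed.MergeFrom(new ReadOnlySpan<byte>(wire));

        Console.WriteLine(parsed.Value); // my_function
    }
}
```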
@@ -10,7 +10,7 @@ namespace Tensorflow.Graphs;
/// </summary>
public class FuncGraph : Graph, IDisposable
{
    SafeFuncGraphHandle _func_graph_handle;
    internal SafeFuncGraphHandle _func_graph_handle;
    public string FuncName => _graph_key;
    public Tensors Inputs { get; set; } = new Tensors();
@@ -238,6 +238,19 @@ namespace Tensorflow
            return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
        }
        [Obsolete("The implementation is not complete.")]
        internal void _set_device_from_string(string device_str)
        {
            // TODO(Rinne): complete it with new C API `SetRequestedDevice`.
            //c_api.TF_SetDevice(_handle, device_str);
        }
        [Obsolete("The implementation is not complete.")]
        internal void _set_device(string device)
        {
            _set_device_from_string(device);
        }
        private NodeDef GetNodeDef()
        {
            var buffer = new Buffer();
@@ -45,11 +45,8 @@ namespace Tensorflow
            _asset_file_def = meta_graph.AssetFileDef;
            _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr);
            _proto = object_graph_proto;
            // Debug(Rinne)
            var temp = _proto.ToString();
            _export_dir = export_dir;
            // TODO: `this._concrete_functions` and `this._restored_concrete_functions`
            // TODO(Rinne): This method is very slow, needs to be accelareted.
            // TODO(Rinne): This method is a bit slow (especially under debug mode) and may need to be accelerated.
            _concrete_functions = function_deserialization.load_function_def_library(
                meta_graph.GraphDef.Library, _proto);
            _restored_concrete_functions = new HashSet<string>();
@@ -322,11 +319,6 @@ namespace Tensorflow
            foreach(var (node_id, proto) in _iter_all_nodes())
            {
                var node = get(node_id);
                if(node is null)
                {
                    // skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
                    continue;
                }
                if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
                {
                    // Restore Trackable serialize- and restore-from-tensor functions.
@@ -390,7 +382,7 @@ namespace Tensorflow
            var optimizer_object = nodes[optimizer_node_id];
            var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId];
            // TODO: implement it.
            // TODO(Rinne): implement it.
            throw new NotImplementedException("The model loading of SavedModel still has some incomplete parts." +
                " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
        }
@@ -508,21 +500,11 @@ namespace Tensorflow
        /// <param name="node_id"></param>
        private void _add_object_graph_edges(SavedObject proto, int node_id)
        {
            // Debug(Rinne)
            if(node_id == 1)
            {
                Console.WriteLine();
            }
            var obj = _nodes[node_id];
            var setter = _node_setters[node_id];
            foreach(var refer in proto.Children)
            {
                if(obj is null)
                {
                    // skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
                    continue;
                }
                setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]);
                // TODO(Rinne): deal with "__call__"
            }
@@ -553,12 +535,6 @@ namespace Tensorflow
        private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
        {
            // skip the registered classes.
            if(node_id == 16)
            {
                // Debug(Rinne)
                Console.WriteLine();
            }
            Dictionary<OneOf<string, int>, Trackable> dependencies = new();
            foreach(var item in _get_node_dependencies(proto))
            {
@@ -65,6 +65,8 @@ namespace Tensorflow
        }
        public void __init__(bool trainable = true,
            Shape shape = null,
            TF_DataType dtype = TF_DataType.DtInvalid,
            Tensor handle = null,
            string name = null,
            string unique_id = null,
@@ -75,6 +77,14 @@ namespace Tensorflow
            _unique_id = unique_id;
            this.handle = handle;
            _name = name;
            if(shape is not null)
            {
                _shape = shape;
            }
            if(dtype != TF_DataType.DtInvalid)
            {
                _dtype = dtype;
            }
            // After the handle has been created, set up a way to clean it up when
            // executing eagerly. We'll hold the only reference to the deleter, so that
@@ -116,7 +116,11 @@ namespace Tensorflow
                }
            });
            _shape = shape ?? _initial_value.shape;
            if(shape is null)
            {
                shape = _initial_value.shape;
            }
            dtype = _initial_value.dtype;
            if (_in_graph_mode)
            {
@@ -135,7 +139,7 @@ namespace Tensorflow
            {
                handle = resource_variable_ops.eager_safe_variable_handle(
                    initial_value: _initial_value,
                    shape: _shape,
                    shape: shape,
                    shared_name: shared_name,
                    name: name,
                    graph_mode: _in_graph_mode);
@@ -154,6 +158,8 @@ namespace Tensorflow
            }
            base.__init__(trainable: trainable,
                shape: shape,
                dtype: dtype,
                handle: handle,
                name: name,
                unique_id: unique_id,
@@ -50,9 +50,9 @@ namespace Tensorflow.Variables
                {
                    tf_with(ops.name_scope("Read"), _ =>
                    {
                        tf.device(handle.Device);
                        var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
                        resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
                        tf.device(created_handle.Device);
                        var value = gen_resource_variable_ops.read_variable_op(created_handle, dtype);
                        resource_variable_ops._maybe_set_handle_data(dtype, created_handle, value);
                        _graph_element = value;
                    });
                    ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this);
@@ -63,9 +63,7 @@ namespace Tensorflow.Variables
                }
            });
        });
        _shape = shape;
        _dtype = dtype;
        base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name);
        base.__init__(trainable, shape, dtype, created_handle, unique_id: unique_id, handle_name: handle_name);
        }
    }
}
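The `handle` → `created_handle` switch in this hunk reads like a fix for a capture slip: inside the `Read` scope the lambda was referencing `handle` from the enclosing scope instead of the handle that had just been created. The standalone sketch below (all names invented for illustration) shows how a closure can quietly pick up the outer, stale identifier rather than the freshly computed one.

```csharp
using System;

class CaptureDemo
{
    static Action MakeReader(string handle)
    {
        // A fresh value is produced locally, analogous to created_handle.
        string createdHandle = handle + "/created";

        // Subtle slip: this closure reads the outer value...
        Action buggy = () => Console.WriteLine($"reading from {handle}");

        // ...while the intended behavior reads the freshly created one.
        Action intended = () => Console.WriteLine($"reading from {createdHandle}");

        intended(); // reading from var0/created
        return buggy;
    }

    static void Main()
    {
        MakeReader("var0")(); // reading from var0
    }
}
```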
@@ -199,11 +199,5 @@ namespace Tensorflow.Keras.Engine
            //}
            base.SetAttr(name, value);
        }
        void IModel.set_stopTraining_true()
        {
            stop_training = true;
        }
    }
}
@@ -307,11 +307,6 @@ namespace Tensorflow.Keras.Saving
        private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json)
        {
            var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);
            // Debug(Rinne)
            if(node_id == 11)
            {
                Console.WriteLine();
            }
            if (loaded_nodes.ContainsKey(node_id))
            {
@@ -124,18 +124,18 @@ public partial class KerasSavedModelUtils
        {
            if (x is ResourceVariable or RefVariable) return (Trackable)x;
            else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
        }));
        }).ToArray());
        var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x =>
        {
            if (x is ResourceVariable or RefVariable) return (Trackable)x;
            else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
        }));
        }).ToArray());
        var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x =>
        {
            if (x is ResourceVariable or RefVariable) return (Trackable)x;
            else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
        }));
        var layers = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable()));
        }).ToArray());
        var layers = TrackableDataStructure.wrap_or_unwrap(list_all_layers(layer).Select(x => x.GetTrackable()).ToArray());
        Dictionary<string, Trackable> res = new();
        Debug.Assert(variables is Trackable);
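A plausible motivation for the added `.ToArray()` calls: `Select` yields a lazily evaluated sequence, so whatever `wrap_or_unwrap` does with it (enumerating it more than once, inspecting its runtime type) happens against a deferred query rather than a fixed collection; materializing it first makes the contents stable and evaluates the projection exactly once. A small standalone illustration of that difference:

```csharp
using System;
using System.Linq;

class DeferredLinqDemo
{
    static void Main()
    {
        int evaluations = 0;
        var source = new[] { 1, 2, 3 };

        // Lazy: the projection runs again on every enumeration.
        var lazy = source.Select(x => { evaluations++; return x * 2; });
        var a = lazy.Sum();
        var b = lazy.Sum();
        Console.WriteLine(evaluations); // 6 — projected twice

        // Materialized: the projection runs exactly once.
        evaluations = 0;
        var eager = source.Select(x => { evaluations++; return x * 2; }).ToArray();
        var c = eager.Sum();
        var d = eager.Sum();
        Console.WriteLine(evaluations); // 3 — projected once
    }
}
```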
@@ -158,6 +158,8 @@ public partial class KerasSavedModelUtils
    /// <returns></returns>
    public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
    {
        // high priority
        // TODO: deal with type `RevivedLayer` and `Sequential`.
        // skip the process because of lack of APIs of `Layer`.
@@ -121,7 +121,10 @@ namespace Tensorflow.Keras.Saving.SavedModel
                }
                else
                {
                    throw new ValueError($"Function {key} missing from serialized function dict.");
                    // high priority
                    // TODO(Rinne): complete the implementation.
                    continue;
                    //throw new ValueError($"Function {key} missing from serialized function dict.");
                }
            }
            return Functions;
@@ -151,7 +154,10 @@ namespace Tensorflow.Keras.Saving.SavedModel
                }
                else
                {
                    throw new ValueError($"Object {key} missing from serialized object dict.");
                    // high priority.
                    // TODO(Rinne): Add the implementation.
                    continue;
                    //throw new ValueError($"Object {key} missing from serialized object dict.");
                }
            }
            return CheckpointableObjects;
@@ -13,12 +13,12 @@ using Tensorflow.Keras;
namespace TensorFlowNET.Keras.UnitTest
{
    [TestClass]
    public class EarltstoppingTest
    public class EarlystoppingTest
    {
        [TestMethod]
        // Because loading the weight variables into the model has not yet been implemented,
        // you'd better not set patience too large; the restored weights will simply be those of the last epoch.
        public void Earltstopping()
        public void Earlystopping()
        {
            var layers = keras.layers;
            var model = keras.Sequential(new List<ILayer>