| @@ -25,5 +25,32 @@ namespace Tensorflow | |||||
| public IntPtr data; | public IntPtr data; | ||||
| public ulong length; | public ulong length; | ||||
| public IntPtr data_deallocator; | public IntPtr data_deallocator; | ||||
| public unsafe Span<T> AsSpan<T>() where T: unmanaged | |||||
| { | |||||
| if(length > int.MaxValue) | |||||
| { | |||||
| throw new ValueError($"The length {length} is too large to use in the span."); | |||||
| } | |||||
| return new Span<T>(data.ToPointer(), (int)length); | |||||
| } | |||||
| public unsafe byte[] ToByteArray() | |||||
| { | |||||
| byte[] res = new byte[length]; | |||||
| if(length > int.MaxValue) | |||||
| { | |||||
| byte* root = (byte*)data; | |||||
| for(ulong i = 0; i < length; i++) | |||||
| { | |||||
| res[i] = *(root++); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| new Span<byte>(data.ToPointer(), (int)length).CopyTo(res.AsSpan()); | |||||
| } | |||||
| return res; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -18,6 +18,10 @@ namespace Tensorflow.Eager | |||||
| var types = v.Select(t => t.dtype.as_datatype_enum()); | var types = v.Select(t => t.dtype.as_datatype_enum()); | ||||
| return (types.ToArray(), v.ToArray()); | return (types.ToArray(), v.ToArray()); | ||||
| } | } | ||||
| public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | |||||
| { | |||||
| return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name); | |||||
| } | |||||
| public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | ||||
| { | { | ||||
| string device_name = ctx.DeviceName; | string device_name = ctx.DeviceName; | ||||
| @@ -149,6 +149,7 @@ namespace Tensorflow | |||||
| foreach (var new_op in graph._add_new_tf_operations()) | foreach (var new_op in graph._add_new_tf_operations()) | ||||
| { | { | ||||
| var original_device = new_op.Device; | var original_device = new_op.Device; | ||||
| new_op._set_device(original_device); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,9 +1,11 @@ | |||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
| using Tensorflow.Eager; | |||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| @@ -16,6 +18,8 @@ namespace Tensorflow.Functions | |||||
| public int _num_outputs; | public int _num_outputs; | ||||
| FuncGraph _func_graph; | FuncGraph _func_graph; | ||||
| FunctionDef _definition; | FunctionDef _definition; | ||||
| OpDef _signature; | |||||
| string _name; | |||||
| Tensor[] _func_graph_outputs; | Tensor[] _func_graph_outputs; | ||||
| public string Name => _func_graph.FuncName; | public string Name => _func_graph.FuncName; | ||||
| public DataType[] OutputTypes { get; protected set; } | public DataType[] OutputTypes { get; protected set; } | ||||
| @@ -31,6 +35,18 @@ namespace Tensorflow.Functions | |||||
| return _definition; | return _definition; | ||||
| } | } | ||||
| } | } | ||||
| public OpDef Signature | |||||
| { | |||||
| get | |||||
| { | |||||
| if( _signature is null) | |||||
| { | |||||
| _signature = Definition.Signature; | |||||
| } | |||||
| return _signature; | |||||
| } | |||||
| } | |||||
| public EagerDefinedFunction(string name, FuncGraph graph, | public EagerDefinedFunction(string name, FuncGraph graph, | ||||
| Tensors inputs, Tensors outputs, | Tensors inputs, Tensors outputs, | ||||
| Dictionary<string, string> attrs) | Dictionary<string, string> attrs) | ||||
| @@ -75,12 +91,12 @@ namespace Tensorflow.Functions | |||||
| Tensor[] outputs; | Tensor[] outputs; | ||||
| if (executing_eagerly) | if (executing_eagerly) | ||||
| { | { | ||||
| outputs = tf.Runner.TFE_Execute(tf.Context, | |||||
| tf.Context.DeviceName, | |||||
| _func_graph.FuncName, | |||||
| args, | |||||
| attrs, | |||||
| _num_outputs); | |||||
| outputs = execute.executes( | |||||
| Signature.Name, | |||||
| _num_outputs, | |||||
| args, | |||||
| attrs, | |||||
| tf.Context); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -135,9 +151,13 @@ namespace Tensorflow.Functions | |||||
| private FunctionDef _get_definition() | private FunctionDef _get_definition() | ||||
| { | { | ||||
| var buffer = c_api_util.tf_buffer(); | var buffer = c_api_util.tf_buffer(); | ||||
| // TODO(Rinne): pywrap_tf_session.TF_FunctionToFunctionDef | |||||
| Status status = new(); | |||||
| c_api.TF_FunctionToFunctionDef(_func_graph._func_graph_handle, buffer, status); | |||||
| status.Check(true); | |||||
| var proto_data = c_api.TF_GetBuffer(buffer); | var proto_data = c_api.TF_GetBuffer(buffer); | ||||
| throw new NotImplementedException(); | |||||
| FunctionDef function_def = new(); | |||||
| function_def.MergeFrom(proto_data.AsSpan<byte>()); | |||||
| return function_def; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,7 +10,7 @@ namespace Tensorflow.Graphs; | |||||
| /// </summary> | /// </summary> | ||||
| public class FuncGraph : Graph, IDisposable | public class FuncGraph : Graph, IDisposable | ||||
| { | { | ||||
| SafeFuncGraphHandle _func_graph_handle; | |||||
| internal SafeFuncGraphHandle _func_graph_handle; | |||||
| public string FuncName => _graph_key; | public string FuncName => _graph_key; | ||||
| public Tensors Inputs { get; set; } = new Tensors(); | public Tensors Inputs { get; set; } = new Tensors(); | ||||
| @@ -238,6 +238,19 @@ namespace Tensorflow | |||||
| return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | ||||
| } | } | ||||
| [Obsolete("The implementation is not complete.")] | |||||
| internal void _set_device_from_string(string device_str) | |||||
| { | |||||
| // TODO(Rinne): complete it with new C API `SetRequestedDevice`. | |||||
| //c_api.TF_SetDevice(_handle, device_str); | |||||
| } | |||||
| [Obsolete("The implementation is not complete.")] | |||||
| internal void _set_device(string device) | |||||
| { | |||||
| _set_device_from_string(device); | |||||
| } | |||||
| private NodeDef GetNodeDef() | private NodeDef GetNodeDef() | ||||
| { | { | ||||
| var buffer = new Buffer(); | var buffer = new Buffer(); | ||||
| @@ -45,11 +45,8 @@ namespace Tensorflow | |||||
| _asset_file_def = meta_graph.AssetFileDef; | _asset_file_def = meta_graph.AssetFileDef; | ||||
| _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); | _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); | ||||
| _proto = object_graph_proto; | _proto = object_graph_proto; | ||||
| // Debug(Rinne) | |||||
| var temp = _proto.ToString(); | |||||
| _export_dir = export_dir; | _export_dir = export_dir; | ||||
| // TODO: `this._concrete_functions` and `this._restored_concrete_functions` | |||||
| // TODO(Rinne): This method is very slow, needs to be accelareted. | |||||
| // TODO(Rinne): This method is a bit slow (especially under debug mode), may need to be accelareted. | |||||
| _concrete_functions = function_deserialization.load_function_def_library( | _concrete_functions = function_deserialization.load_function_def_library( | ||||
| meta_graph.GraphDef.Library, _proto); | meta_graph.GraphDef.Library, _proto); | ||||
| _restored_concrete_functions = new HashSet<string>(); | _restored_concrete_functions = new HashSet<string>(); | ||||
| @@ -322,11 +319,6 @@ namespace Tensorflow | |||||
| foreach(var (node_id, proto) in _iter_all_nodes()) | foreach(var (node_id, proto) in _iter_all_nodes()) | ||||
| { | { | ||||
| var node = get(node_id); | var node = get(node_id); | ||||
| if(node is null) | |||||
| { | |||||
| // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. | |||||
| continue; | |||||
| } | |||||
| if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | ||||
| { | { | ||||
| // Restore Trackable serialize- and restore-from-tensor functions. | // Restore Trackable serialize- and restore-from-tensor functions. | ||||
| @@ -390,7 +382,7 @@ namespace Tensorflow | |||||
| var optimizer_object = nodes[optimizer_node_id]; | var optimizer_object = nodes[optimizer_node_id]; | ||||
| var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; | var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; | ||||
| // TODO: implement it. | |||||
| // TODO(Rinne): implement it. | |||||
| throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." + | throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." + | ||||
| " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | ||||
| } | } | ||||
| @@ -508,21 +500,11 @@ namespace Tensorflow | |||||
| /// <param name="node_id"></param> | /// <param name="node_id"></param> | ||||
| private void _add_object_graph_edges(SavedObject proto, int node_id) | private void _add_object_graph_edges(SavedObject proto, int node_id) | ||||
| { | { | ||||
| // Debug(Rinne) | |||||
| if(node_id == 1) | |||||
| { | |||||
| Console.WriteLine(); | |||||
| } | |||||
| var obj = _nodes[node_id]; | var obj = _nodes[node_id]; | ||||
| var setter = _node_setters[node_id]; | var setter = _node_setters[node_id]; | ||||
| foreach(var refer in proto.Children) | foreach(var refer in proto.Children) | ||||
| { | { | ||||
| if(obj is null) | |||||
| { | |||||
| // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. | |||||
| continue; | |||||
| } | |||||
| setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); | setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); | ||||
| // TODO(Rinne): deal with "__call__" | // TODO(Rinne): deal with "__call__" | ||||
| } | } | ||||
| @@ -553,12 +535,6 @@ namespace Tensorflow | |||||
| private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes) | private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes) | ||||
| { | { | ||||
| // skip the registered classes. | // skip the registered classes. | ||||
| if(node_id == 16) | |||||
| { | |||||
| // Debug(Rinne) | |||||
| Console.WriteLine(); | |||||
| } | |||||
| Dictionary<OneOf<string, int>, Trackable> dependencies = new(); | Dictionary<OneOf<string, int>, Trackable> dependencies = new(); | ||||
| foreach(var item in _get_node_dependencies(proto)) | foreach(var item in _get_node_dependencies(proto)) | ||||
| { | { | ||||
| @@ -65,6 +65,8 @@ namespace Tensorflow | |||||
| } | } | ||||
| public void __init__(bool trainable = true, | public void __init__(bool trainable = true, | ||||
| Shape shape = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| Tensor handle = null, | Tensor handle = null, | ||||
| string name = null, | string name = null, | ||||
| string unique_id = null, | string unique_id = null, | ||||
| @@ -75,6 +77,14 @@ namespace Tensorflow | |||||
| _unique_id = unique_id; | _unique_id = unique_id; | ||||
| this.handle = handle; | this.handle = handle; | ||||
| _name = name; | _name = name; | ||||
| if(shape is not null) | |||||
| { | |||||
| _shape = shape; | |||||
| } | |||||
| if(dtype != TF_DataType.DtInvalid) | |||||
| { | |||||
| _dtype = dtype; | |||||
| } | |||||
| // After the handle has been created, set up a way to clean it up when | // After the handle has been created, set up a way to clean it up when | ||||
| // executing eagerly. We'll hold the only reference to the deleter, so that | // executing eagerly. We'll hold the only reference to the deleter, so that | ||||
| @@ -116,7 +116,11 @@ namespace Tensorflow | |||||
| } | } | ||||
| }); | }); | ||||
| _shape = shape ?? _initial_value.shape; | |||||
| if(shape is null) | |||||
| { | |||||
| shape = _initial_value.shape; | |||||
| } | |||||
| dtype = _initial_value.dtype; | |||||
| if (_in_graph_mode) | if (_in_graph_mode) | ||||
| { | { | ||||
| @@ -135,7 +139,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| handle = resource_variable_ops.eager_safe_variable_handle( | handle = resource_variable_ops.eager_safe_variable_handle( | ||||
| initial_value: _initial_value, | initial_value: _initial_value, | ||||
| shape: _shape, | |||||
| shape: shape, | |||||
| shared_name: shared_name, | shared_name: shared_name, | ||||
| name: name, | name: name, | ||||
| graph_mode: _in_graph_mode); | graph_mode: _in_graph_mode); | ||||
| @@ -154,6 +158,8 @@ namespace Tensorflow | |||||
| } | } | ||||
| base.__init__(trainable: trainable, | base.__init__(trainable: trainable, | ||||
| shape: shape, | |||||
| dtype: dtype, | |||||
| handle: handle, | handle: handle, | ||||
| name: name, | name: name, | ||||
| unique_id: unique_id, | unique_id: unique_id, | ||||
| @@ -50,9 +50,9 @@ namespace Tensorflow.Variables | |||||
| { | { | ||||
| tf_with(ops.name_scope("Read"), _ => | tf_with(ops.name_scope("Read"), _ => | ||||
| { | { | ||||
| tf.device(handle.Device); | |||||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||||
| tf.device(created_handle.Device); | |||||
| var value = gen_resource_variable_ops.read_variable_op(created_handle, dtype); | |||||
| resource_variable_ops._maybe_set_handle_data(dtype, created_handle, value); | |||||
| _graph_element = value; | _graph_element = value; | ||||
| }); | }); | ||||
| ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); | ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); | ||||
| @@ -63,9 +63,7 @@ namespace Tensorflow.Variables | |||||
| } | } | ||||
| }); | }); | ||||
| }); | }); | ||||
| _shape = shape; | |||||
| _dtype = dtype; | |||||
| base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name); | |||||
| base.__init__(trainable, shape, dtype, created_handle, unique_id: unique_id, handle_name: handle_name); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -199,11 +199,5 @@ namespace Tensorflow.Keras.Engine | |||||
| //} | //} | ||||
| base.SetAttr(name, value); | base.SetAttr(name, value); | ||||
| } | } | ||||
| void IModel.set_stopTraining_true() | |||||
| { | |||||
| stop_training = true; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -307,11 +307,6 @@ namespace Tensorflow.Keras.Saving | |||||
| private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | ||||
| { | { | ||||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | ||||
| // Debug(Rinne) | |||||
| if(node_id == 11) | |||||
| { | |||||
| Console.WriteLine(); | |||||
| } | |||||
| if (loaded_nodes.ContainsKey(node_id)) | if (loaded_nodes.ContainsKey(node_id)) | ||||
| { | { | ||||
| @@ -124,18 +124,18 @@ public partial class KerasSavedModelUtils | |||||
| { | { | ||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | if (x is ResourceVariable or RefVariable) return (Trackable)x; | ||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | ||||
| })); | |||||
| }).ToArray()); | |||||
| var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => | var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => | ||||
| { | { | ||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | if (x is ResourceVariable or RefVariable) return (Trackable)x; | ||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | ||||
| })); | |||||
| }).ToArray()); | |||||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x => | var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x => | ||||
| { | { | ||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | if (x is ResourceVariable or RefVariable) return (Trackable)x; | ||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | ||||
| })); | |||||
| var layers = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||||
| }).ToArray()); | |||||
| var layers = TrackableDataStructure.wrap_or_unwrap(list_all_layers(layer).Select(x => x.GetTrackable()).ToArray()); | |||||
| Dictionary<string, Trackable> res = new(); | Dictionary<string, Trackable> res = new(); | ||||
| Debug.Assert(variables is Trackable); | Debug.Assert(variables is Trackable); | ||||
| @@ -158,6 +158,8 @@ public partial class KerasSavedModelUtils | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | ||||
| { | { | ||||
| // high priority | |||||
| // TODO: deal with type `RevivedLayer` and `Sequential`. | // TODO: deal with type `RevivedLayer` and `Sequential`. | ||||
| // skip the process because of lack of APIs of `Layer`. | // skip the process because of lack of APIs of `Layer`. | ||||
| @@ -121,7 +121,10 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| throw new ValueError($"Function {key} missing from serialized function dict."); | |||||
| // high priority | |||||
| // TODO(Rinne): complete the implementation. | |||||
| continue; | |||||
| //throw new ValueError($"Function {key} missing from serialized function dict."); | |||||
| } | } | ||||
| } | } | ||||
| return Functions; | return Functions; | ||||
| @@ -151,7 +154,10 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| throw new ValueError($"Object {key} missing from serialized object dict."); | |||||
| // high priority. | |||||
| // TODO(Rinne): Add the implementation. | |||||
| continue; | |||||
| //throw new ValueError($"Object {key} missing from serialized object dict."); | |||||
| } | } | ||||
| } | } | ||||
| return CheckpointableObjects; | return CheckpointableObjects; | ||||
| @@ -13,12 +13,12 @@ using Tensorflow.Keras; | |||||
| namespace TensorFlowNET.Keras.UnitTest | namespace TensorFlowNET.Keras.UnitTest | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class EarltstoppingTest | |||||
| public class EarlystoppingTest | |||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| // Because loading the weight variable into the model has not yet been implemented, | // Because loading the weight variable into the model has not yet been implemented, | ||||
| // so you'd better not set patience too large, because the weights will equal to the last epoch's weights. | // so you'd better not set patience too large, because the weights will equal to the last epoch's weights. | ||||
| public void Earltstopping() | |||||
| public void Earlystopping() | |||||
| { | { | ||||
| var layers = keras.layers; | var layers = keras.layers; | ||||
| var model = keras.Sequential(new List<ILayer> | var model = keras.Sequential(new List<ILayer> | ||||