Browse Source

Resolve some wrong implementations.

tags/v0.100.5-BERT-load
Yaohui Liu 2 years ago
parent
commit
6a9ccea29f
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
15 changed files with 114 additions and 62 deletions
  1. +27
    -0
      src/TensorFlowNET.Core/Buffers/TF_Buffer.cs
  2. +4
    -0
      src/TensorFlowNET.Core/Eager/execute.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Framework/importer.cs
  4. +28
    -8
      src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  6. +13
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  7. +2
    -26
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  8. +10
    -0
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  9. +8
    -2
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  10. +4
    -6
      src/TensorFlowNET.Core/Variables/UninitializedVariable.cs
  11. +0
    -6
      src/TensorFlowNET.Keras/Engine/Model.cs
  12. +0
    -5
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  13. +6
    -4
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  14. +8
    -2
      src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
  15. +2
    -2
      test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs

+ 27
- 0
src/TensorFlowNET.Core/Buffers/TF_Buffer.cs View File

@@ -25,5 +25,32 @@ namespace Tensorflow
public IntPtr data; public IntPtr data;
public ulong length; public ulong length;
public IntPtr data_deallocator; public IntPtr data_deallocator;

/// <summary>
/// Exposes the native buffer as a span of <typeparamref name="T"/> without copying.
/// </summary>
/// <typeparam name="T">Unmanaged element type used to reinterpret the buffer contents.</typeparam>
/// <returns>
/// A span starting at <c>data</c> that holds <c>length / sizeof(T)</c> elements.
/// Trailing bytes that do not form a whole element are not included.
/// </returns>
/// <exception cref="ValueError">The element count does not fit in an <see cref="int"/>.</exception>
public unsafe Span<T> AsSpan<T>() where T: unmanaged
{
    // `length` is a byte count (ToByteArray copies exactly `length` bytes), while the
    // Span<T> constructor expects a number of T elements. Passing the byte count
    // directly would let the span run past the end of the buffer for any T wider
    // than one byte, so convert it first. For T = byte the result is unchanged.
    ulong elementCount = length / (ulong)sizeof(T);
    if (elementCount > int.MaxValue)
    {
        throw new ValueError($"The length {length} is too large to use in the span.");
    }
    return new Span<T>(data.ToPointer(), (int)elementCount);
}

/// <summary>
/// Copies the contents of the native buffer into a newly allocated managed array.
/// </summary>
/// <returns>A new array holding the <c>length</c> bytes found at <c>data</c>.</returns>
/// <exception cref="ValueError">The buffer is too large to fit in a managed byte array.</exception>
public unsafe byte[] ToByteArray()
{
    // A managed byte[] cannot hold more than int.MaxValue elements, so the allocation
    // itself fails for larger buffers and a byte-by-byte fallback for that case can
    // never run. Fail early with a descriptive error instead.
    if (length > int.MaxValue)
    {
        throw new ValueError($"The length {length} is too large to copy into a byte array.");
    }

    int byteCount = (int)length;
    byte[] res = new byte[byteCount];
    new Span<byte>(data.ToPointer(), byteCount).CopyTo(res.AsSpan());
    return res;
}
} }
} }

+ 4
- 0
src/TensorFlowNET.Core/Eager/execute.cs View File

@@ -18,6 +18,10 @@ namespace Tensorflow.Eager
var types = v.Select(t => t.dtype.as_datatype_enum()); var types = v.Select(t => t.dtype.as_datatype_enum());
return (types.ToArray(), v.ToArray()); return (types.ToArray(), v.ToArray());
} }
/// <summary>
/// Executes an op by forwarding every argument, unchanged, to <c>quick_execute</c>.
/// </summary>
/// <param name="op_name">Name of the op to execute.</param>
/// <param name="num_outputs">Number of outputs passed through to <c>quick_execute</c>.</param>
/// <param name="inputs">Input tensors.</param>
/// <param name="attrs">Op attributes.</param>
/// <param name="ctx">Context passed through to <c>quick_execute</c>.</param>
/// <param name="name">Optional name; defaults to <c>null</c>.</param>
/// <returns>Whatever <c>quick_execute</c> returns for the same arguments.</returns>
public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
    => quick_execute(op_name, num_outputs, inputs, attrs, ctx, name);
public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
{ {
string device_name = ctx.DeviceName; string device_name = ctx.DeviceName;


+ 1
- 0
src/TensorFlowNET.Core/Framework/importer.cs View File

@@ -149,6 +149,7 @@ namespace Tensorflow
foreach (var new_op in graph._add_new_tf_operations()) foreach (var new_op in graph._add_new_tf_operations())
{ {
var original_device = new_op.Device; var original_device = new_op.Device;
new_op._set_device(original_device);
} }
} }




+ 28
- 8
src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs View File

@@ -1,9 +1,11 @@
using Google.Protobuf; using Google.Protobuf;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Contexts; using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Graphs; using Tensorflow.Graphs;
using Tensorflow.Operations; using Tensorflow.Operations;
using Tensorflow.Util; using Tensorflow.Util;
@@ -16,6 +18,8 @@ namespace Tensorflow.Functions
public int _num_outputs; public int _num_outputs;
FuncGraph _func_graph; FuncGraph _func_graph;
FunctionDef _definition; FunctionDef _definition;
OpDef _signature;
string _name;
Tensor[] _func_graph_outputs; Tensor[] _func_graph_outputs;
public string Name => _func_graph.FuncName; public string Name => _func_graph.FuncName;
public DataType[] OutputTypes { get; protected set; } public DataType[] OutputTypes { get; protected set; }
@@ -31,6 +35,18 @@ namespace Tensorflow.Functions
return _definition; return _definition;
} }
} }

/// <summary>
/// The signature taken from <c>Definition</c>. It is read on first access and
/// cached in <c>_signature</c> for later accesses.
/// </summary>
public OpDef Signature => _signature ??= Definition.Signature;
public EagerDefinedFunction(string name, FuncGraph graph, public EagerDefinedFunction(string name, FuncGraph graph,
Tensors inputs, Tensors outputs, Tensors inputs, Tensors outputs,
Dictionary<string, string> attrs) Dictionary<string, string> attrs)
@@ -75,12 +91,12 @@ namespace Tensorflow.Functions
Tensor[] outputs; Tensor[] outputs;
if (executing_eagerly) if (executing_eagerly)
{ {
outputs = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
_func_graph.FuncName,
args,
attrs,
_num_outputs);
outputs = execute.executes(
Signature.Name,
_num_outputs,
args,
attrs,
tf.Context);
} }
else else
{ {
@@ -135,9 +151,13 @@ namespace Tensorflow.Functions
private FunctionDef _get_definition() private FunctionDef _get_definition()
{ {
var buffer = c_api_util.tf_buffer(); var buffer = c_api_util.tf_buffer();
// TODO(Rinne): pywrap_tf_session.TF_FunctionToFunctionDef
Status status = new();
c_api.TF_FunctionToFunctionDef(_func_graph._func_graph_handle, buffer, status);
status.Check(true);
var proto_data = c_api.TF_GetBuffer(buffer); var proto_data = c_api.TF_GetBuffer(buffer);
throw new NotImplementedException();
FunctionDef function_def = new();
function_def.MergeFrom(proto_data.AsSpan<byte>());
return function_def;
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow.Graphs;
/// </summary> /// </summary>
public class FuncGraph : Graph, IDisposable public class FuncGraph : Graph, IDisposable
{ {
SafeFuncGraphHandle _func_graph_handle;
internal SafeFuncGraphHandle _func_graph_handle;
public string FuncName => _graph_key; public string FuncName => _graph_key;


public Tensors Inputs { get; set; } = new Tensors(); public Tensors Inputs { get; set; } = new Tensors();


+ 13
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -238,6 +238,19 @@ namespace Tensorflow
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
} }


/// <summary>
/// Placeholder for assigning a device to this operation. The body is currently
/// empty, so calling it has no effect.
/// </summary>
/// <param name="device_str">Device string; currently unused.</param>
[Obsolete("The implementation is not complete.")]
internal void _set_device_from_string(string device_str)
{
// TODO(Rinne): complete it with new C API `SetRequestedDevice`.
// The call below is kept commented out until that API is available.
//c_api.TF_SetDevice(_handle, device_str);
}

/// <summary>
/// Forwards <paramref name="device"/> to <c>_set_device_from_string</c>.
/// </summary>
/// <param name="device">Device string passed through unchanged.</param>
[Obsolete("The implementation is not complete.")]
internal void _set_device(string device) => _set_device_from_string(device);

private NodeDef GetNodeDef() private NodeDef GetNodeDef()
{ {
var buffer = new Buffer(); var buffer = new Buffer();


+ 2
- 26
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -45,11 +45,8 @@ namespace Tensorflow
_asset_file_def = meta_graph.AssetFileDef; _asset_file_def = meta_graph.AssetFileDef;
_operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr);
_proto = object_graph_proto; _proto = object_graph_proto;
// Debug(Rinne)
var temp = _proto.ToString();
_export_dir = export_dir; _export_dir = export_dir;
// TODO: `this._concrete_functions` and `this._restored_concrete_functions`
// TODO(Rinne): This method is very slow, needs to be accelareted.
// TODO(Rinne): This method is a bit slow (especially under debug mode) and may need to be accelerated.
_concrete_functions = function_deserialization.load_function_def_library( _concrete_functions = function_deserialization.load_function_def_library(
meta_graph.GraphDef.Library, _proto); meta_graph.GraphDef.Library, _proto);
_restored_concrete_functions = new HashSet<string>(); _restored_concrete_functions = new HashSet<string>();
@@ -322,11 +319,6 @@ namespace Tensorflow
foreach(var (node_id, proto) in _iter_all_nodes()) foreach(var (node_id, proto) in _iter_all_nodes())
{ {
var node = get(node_id); var node = get(node_id);
if(node is null)
{
// skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
continue;
}
if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
{ {
// Restore Trackable serialize- and restore-from-tensor functions. // Restore Trackable serialize- and restore-from-tensor functions.
@@ -390,7 +382,7 @@ namespace Tensorflow
var optimizer_object = nodes[optimizer_node_id]; var optimizer_object = nodes[optimizer_node_id];
var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId];


// TODO: implement it.
// TODO(Rinne): implement it.
throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." + throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." +
" Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
} }
@@ -508,21 +500,11 @@ namespace Tensorflow
/// <param name="node_id"></param> /// <param name="node_id"></param>
private void _add_object_graph_edges(SavedObject proto, int node_id) private void _add_object_graph_edges(SavedObject proto, int node_id)
{ {
// Debug(Rinne)
if(node_id == 1)
{
Console.WriteLine();
}
var obj = _nodes[node_id]; var obj = _nodes[node_id];
var setter = _node_setters[node_id]; var setter = _node_setters[node_id];


foreach(var refer in proto.Children) foreach(var refer in proto.Children)
{ {
if(obj is null)
{
// skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
continue;
}
setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]);
// TODO(Rinne): deal with "__call__" // TODO(Rinne): deal with "__call__"
} }
@@ -553,12 +535,6 @@ namespace Tensorflow
private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes) private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
{ {
// skip the registered classes. // skip the registered classes.
if(node_id == 16)
{
// Debug(Rinne)
Console.WriteLine();
}

Dictionary<OneOf<string, int>, Trackable> dependencies = new(); Dictionary<OneOf<string, int>, Trackable> dependencies = new();
foreach(var item in _get_node_dependencies(proto)) foreach(var item in _get_node_dependencies(proto))
{ {


+ 10
- 0
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -65,6 +65,8 @@ namespace Tensorflow
} }


public void __init__(bool trainable = true, public void __init__(bool trainable = true,
Shape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
Tensor handle = null, Tensor handle = null,
string name = null, string name = null,
string unique_id = null, string unique_id = null,
@@ -75,6 +77,14 @@ namespace Tensorflow
_unique_id = unique_id; _unique_id = unique_id;
this.handle = handle; this.handle = handle;
_name = name; _name = name;
if(shape is not null)
{
_shape = shape;
}
if(dtype != TF_DataType.DtInvalid)
{
_dtype = dtype;
}


// After the handle has been created, set up a way to clean it up when // After the handle has been created, set up a way to clean it up when
// executing eagerly. We'll hold the only reference to the deleter, so that // executing eagerly. We'll hold the only reference to the deleter, so that


+ 8
- 2
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -116,7 +116,11 @@ namespace Tensorflow
} }
}); });


_shape = shape ?? _initial_value.shape;
if(shape is null)
{
shape = _initial_value.shape;
}
dtype = _initial_value.dtype;


if (_in_graph_mode) if (_in_graph_mode)
{ {
@@ -135,7 +139,7 @@ namespace Tensorflow
{ {
handle = resource_variable_ops.eager_safe_variable_handle( handle = resource_variable_ops.eager_safe_variable_handle(
initial_value: _initial_value, initial_value: _initial_value,
shape: _shape,
shape: shape,
shared_name: shared_name, shared_name: shared_name,
name: name, name: name,
graph_mode: _in_graph_mode); graph_mode: _in_graph_mode);
@@ -154,6 +158,8 @@ namespace Tensorflow
} }


base.__init__(trainable: trainable, base.__init__(trainable: trainable,
shape: shape,
dtype: dtype,
handle: handle, handle: handle,
name: name, name: name,
unique_id: unique_id, unique_id: unique_id,


+ 4
- 6
src/TensorFlowNET.Core/Variables/UninitializedVariable.cs View File

@@ -50,9 +50,9 @@ namespace Tensorflow.Variables
{ {
tf_with(ops.name_scope("Read"), _ => tf_with(ops.name_scope("Read"), _ =>
{ {
tf.device(handle.Device);
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
tf.device(created_handle.Device);
var value = gen_resource_variable_ops.read_variable_op(created_handle, dtype);
resource_variable_ops._maybe_set_handle_data(dtype, created_handle, value);
_graph_element = value; _graph_element = value;
}); });
ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this);
@@ -63,9 +63,7 @@ namespace Tensorflow.Variables
} }
}); });
}); });
_shape = shape;
_dtype = dtype;
base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name);
base.__init__(trainable, shape, dtype, created_handle, unique_id: unique_id, handle_name: handle_name);
} }
} }
} }

+ 0
- 6
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -199,11 +199,5 @@ namespace Tensorflow.Keras.Engine
//} //}
base.SetAttr(name, value); base.SetAttr(name, value);
} }


void IModel.set_stopTraining_true()
{
stop_training = true;
}
} }
} }

+ 0
- 5
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -307,11 +307,6 @@ namespace Tensorflow.Keras.Saving
private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json)
{ {
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);
// Debug(Rinne)
if(node_id == 11)
{
Console.WriteLine();
}


if (loaded_nodes.ContainsKey(node_id)) if (loaded_nodes.ContainsKey(node_id))
{ {


+ 6
- 4
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -124,18 +124,18 @@ public partial class KerasSavedModelUtils
{ {
if (x is ResourceVariable or RefVariable) return (Trackable)x; if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
}));
}).ToArray());
var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x =>
{ {
if (x is ResourceVariable or RefVariable) return (Trackable)x; if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
}));
}).ToArray());
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x => var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x =>
{ {
if (x is ResourceVariable or RefVariable) return (Trackable)x; if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
}));
var layers = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable()));
}).ToArray());
var layers = TrackableDataStructure.wrap_or_unwrap(list_all_layers(layer).Select(x => x.GetTrackable()).ToArray());


Dictionary<string, Trackable> res = new(); Dictionary<string, Trackable> res = new();
Debug.Assert(variables is Trackable); Debug.Assert(variables is Trackable);
@@ -158,6 +158,8 @@ public partial class KerasSavedModelUtils
/// <returns></returns> /// <returns></returns>
public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{ {

// high priority
// TODO: deal with type `RevivedLayer` and `Sequential`. // TODO: deal with type `RevivedLayer` and `Sequential`.


// skip the process because of lack of APIs of `Layer`. // skip the process because of lack of APIs of `Layer`.


+ 8
- 2
src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs View File

@@ -121,7 +121,10 @@ namespace Tensorflow.Keras.Saving.SavedModel
} }
else else
{ {
throw new ValueError($"Function {key} missing from serialized function dict.");
// high priority
// TODO(Rinne): complete the implementation.
continue;
//throw new ValueError($"Function {key} missing from serialized function dict.");
} }
} }
return Functions; return Functions;
@@ -151,7 +154,10 @@ namespace Tensorflow.Keras.Saving.SavedModel
} }
else else
{ {
throw new ValueError($"Object {key} missing from serialized object dict.");
// high priority.
// TODO(Rinne): Add the implementation.
continue;
//throw new ValueError($"Object {key} missing from serialized object dict.");
} }
} }
return CheckpointableObjects; return CheckpointableObjects;


+ 2
- 2
test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs View File

@@ -13,12 +13,12 @@ using Tensorflow.Keras;
namespace TensorFlowNET.Keras.UnitTest namespace TensorFlowNET.Keras.UnitTest
{ {
[TestClass] [TestClass]
public class EarltstoppingTest
public class EarlystoppingTest
{ {
[TestMethod] [TestMethod]
// Because loading the weight variable into the model has not yet been implemented, // Because loading the weight variable into the model has not yet been implemented,
// so you'd better not set patience too large, because the weights will equal to the last epoch's weights. // so you'd better not set patience too large, because the weights will equal to the last epoch's weights.
public void Earltstopping()
public void Earlystopping()
{ {
var layers = keras.layers; var layers = keras.layers;
var model = keras.Sequential(new List<ILayer> var model = keras.Sequential(new List<ILayer>


Loading…
Cancel
Save