Browse Source

Fix the error that loaded concrete function does not work.

tags/v0.100.5-BERT-load
Yaohui Liu 2 years ago
parent
commit
4252952208
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
4 changed files with 32 additions and 17 deletions
  1. +26
    -10
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  2. +2
    -1
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  3. +2
    -5
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  4. +2
    -1
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

+ 26
- 10
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -6,6 +6,7 @@ using Tensorflow.Eager;
using Tensorflow.Framework.Models;
using Tensorflow.Graphs;
using Tensorflow.Train;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
@@ -21,6 +22,7 @@ namespace Tensorflow.Functions
protected Dictionary<string, string> _attrs;
protected FunctionSpec _function_spec;
protected FunctionSpec _pre_initialized_function_spec = null;
protected EagerDefinedFunction _inference_function;
internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures;
@@ -39,6 +41,7 @@ namespace Tensorflow.Functions
_captured_inputs = func_graph.external_captures;
_attrs= new Dictionary<string, string>();
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs);
_inference_function = _delayed_rewrite_functions.Forward();
}

public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null)
@@ -49,6 +52,7 @@ namespace Tensorflow.Functions
//ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray());
_attrs = attrs;
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs);
_inference_function = _delayed_rewrite_functions.Forward();
}

public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
@@ -69,6 +73,7 @@ namespace Tensorflow.Functions
_captured_inputs = func_graph.external_captures;
_attrs = new Dictionary<string, string>();
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs);
_inference_function = _delayed_rewrite_functions.Forward();
}

public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
@@ -92,6 +97,7 @@ namespace Tensorflow.Functions
_captured_inputs = func_graph.external_captures;
_attrs = new Dictionary<string, string>();
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs);
_inference_function = _delayed_rewrite_functions.Forward();
}

/*public ConcreteFunction(Func<Tensors, Tensors> func,
@@ -154,9 +160,10 @@ namespace Tensorflow.Functions
{
tensor_inputs.Add(arg);
// If we're graph building, shape inference is on.
if (!executing_eagerly)
{
}
}
if (!executing_eagerly)
{

}
tensor_inputs.AddRange(captured_inputs);

@@ -166,12 +173,13 @@ namespace Tensorflow.Functions
// No tape is watching; skip to running the function.
if (possible_gradient_type == 0 && executing_eagerly)
{
var attrs = new object[]
{
"executor_type", "",
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
};
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
return _build_call_outputs(_inference_function.Call(args));
//var attrs = new object[]
//{
// "executor_type", "",
// "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
//};
//return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
}

if (forward_backward == null)
@@ -184,10 +192,11 @@ namespace Tensorflow.Functions
}
else
{
// TODO(Rinne): add `default_graph._override_gradient_function`.
flat_outputs = forward_function.Call(args_with_tangents);
}
forward_backward.Record(flat_outputs);
return flat_outputs;
return _build_call_outputs(flat_outputs);
}

public void AddTograph(Graph? g = null)
@@ -262,6 +271,13 @@ namespace Tensorflow.Functions
};
}

private Tensors _build_call_outputs(Tensors result)
{
// TODO(Rinne): dwal with `func_graph.structured_outputs`

return result;
}

public override string ToString()
=> Name;
}


+ 2
- 1
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -124,11 +124,12 @@ namespace Tensorflow.Keras.Engine
if (set_inputs || _is_graph_network)
{
_init_graph_network(inputs, outputs);
_is_graph_network = true;
_graph_initialized = true;
}
else
{
_self_tracked_trackables.add(layer);
// TODO(Rinne): self._handle_deferred_layer_dependencies([layer])
}
}



+ 2
- 5
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -189,11 +189,8 @@ namespace Tensorflow.Keras.Saving
}

// `model.__init__(layers, config["name"])`InitLayers(layers);
s = new Sequential(new SequentialArgs(){
Layers = layers.Select(x => x as ILayer).ToList(),
Name = config["name"].ToObject<string>()
});
//s.Name = config["name"].ToObject<string>();
s.InitLayers(layers.Select(x => x as ILayer));
s.Name = config["name"].ToObject<string>();
if(s.inputs is null || s.inputs.Length == 0)
{
var first_layer = _get_child_layer_node_ids(model_id)[0];


+ 2
- 1
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -62,10 +62,11 @@ public class SequentialModelLoad
[TestMethod]
public void Temp()
{
var model = tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func");
var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func");
model.summary();

var x = tf.ones((2, 10));
var y = model.Apply(x);
Console.WriteLine(y);
}
}

Loading…
Cancel
Save