diff --git a/src/TensorFlowNET.Core/APIs/tf.compat.cs b/src/TensorFlowNET.Core/APIs/tf.compat.cs index 5b2b5a10..8a30badd 100644 --- a/src/TensorFlowNET.Core/APIs/tf.compat.cs +++ b/src/TensorFlowNET.Core/APIs/tf.compat.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using Google.Protobuf; using System.Text; namespace Tensorflow @@ -45,6 +46,23 @@ namespace Tensorflow { return as_text(bytes_or_text, encoding); } + + public ByteString as_bytes(ByteString bytes, Encoding encoding = null) + { + return bytes; + } + public ByteString as_bytes(byte[] bytes, Encoding encoding = null) + { + return ByteString.CopyFrom(bytes); + } + public ByteString as_bytes(string text, Encoding encoding = null) + { + if(encoding is null) + { + encoding = Encoding.UTF8; + } + return ByteString.CopyFrom(encoding.GetBytes(text)); + } } public bool executing_eagerly() diff --git a/src/TensorFlowNET.Core/APIs/tf.io.cs b/src/TensorFlowNET.Core/APIs/tf.io.cs index 0c0510dd..be1e86e6 100644 --- a/src/TensorFlowNET.Core/APIs/tf.io.cs +++ b/src/TensorFlowNET.Core/APIs/tf.io.cs @@ -54,6 +54,6 @@ namespace Tensorflow Dictionary input_map = null, string[] return_elements = null, string name = null, - OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list); + OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list); } } diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 21a14831..efb6b0fc 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -156,6 +156,12 @@ namespace Tensorflow.Contexts return has_graph_arg; } + public bool has_function(string name) + { + ensure_initialized(); + return c_api.TFE_ContextHasFunction(_handle, name); + } + public void restore_mode() { context_switches.Pop(); diff --git a/src/TensorFlowNET.Core/Framework/function_def_lib.cs b/src/TensorFlowNET.Core/Framework/function_def_lib.cs new file mode 100644 index 00000000..b81cb71b --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/function_def_lib.cs @@ -0,0 +1,288 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Security.Cryptography; +using System.Text; +using Tensorflow.Graphs; +using static Tensorflow.Binding; +using static Tensorflow.CppShapeInferenceResult.Types; + +namespace Tensorflow.Framework +{ + public class function_def_lib + { + // TODO(Rinne): process signatures and structured outputs. + public static FuncGraph function_def_to_graph(FunctionDef fdef, object? structured_input_signature, + object? structured_outputs, List input_shapes = null) + { + var func_graph = new FuncGraph(fdef.Signature.Name); + if(input_shapes is null) + { + if(fdef.Attr.TryGetValue("_input_shapes", out var input_shapes_attr)) + { + var raw_input_shapes = input_shapes_attr.List.Shape; + input_shapes = new List(); + foreach(var (input_shape, arg_def) in raw_input_shapes.Zip(fdef.Signature.InputArg, (x, y) => (x, y))) + { + if(arg_def.Type == DataType.DtResource && arg_def.HandleData is not null && arg_def.HandleData.Count > 0) + { + input_shapes.Add(null); + } + else + { + input_shapes.Add(input_shape); + } + } + } + } + + var (graph_def, nested_to_flat_tensor_name) = function_def_to_graph_def(fdef, input_shapes); + + func_graph.as_default(); + importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); + var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); + func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); + + var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); + func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); + // TODO(Rinne): func_graph.ControlOutputs + _set_handle_data(func_graph, fdef); + + foreach(var node in graph_def.Node) + { + if(node.Attr.TryGetValue("_output_shapes", out var output_shapes)) + { + var op = func_graph.get_operation_by_name(node.Name); + foreach(var (output_index, shape) in enumerate(output_shapes.List.Shape.Take(op.outputs.Length))) + { + op.outputs[output_index].shape = new Shape(shape); + } + } + } + Dictionary output_names = new(); + foreach(var (ret_arg_def, tensor_name) in zip(fdef.Signature.OutputArg, output_tensor_names)) + { + output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name; + } + // TODO(Rinne): func_graph._output_names = output_names + + func_graph.Exit(); + return func_graph; + } + + public static (GraphDef, Dictionary) function_def_to_graph_def(FunctionDef fdef, List input_shapes) + { + var graph_def = new GraphDef() + { + Versions = new VersionDef() + { + Producer = versions.GRAPH_DEF_VERSION, + MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER + } + }; + + var default_graph = ops.get_default_graph(); + + if(input_shapes is not null && input_shapes.Count > 0 && input_shapes.Count != fdef.Signature.InputArg.Count) + { + throw new ValueError($"Length of `input_shapes` must match the number " + + $"of `input_arg`s in `fdef`. Got {input_shapes.Count} `input_shapes` and " + + $"{fdef.Signature.InputArg.Count} `input_arg`s."); + } + + foreach(var (i, arg_def) in enumerate(fdef.Signature.InputArg)) + { + NodeDef node_def = new(); + node_def.Name = arg_def.Name; + node_def.Op = "Placeholder"; + node_def.Attr["dtype"] = new AttrValue() + { + Type = arg_def.Type + }; + if(input_shapes is not null && input_shapes.Count > 0 && input_shapes[i] is not null) + { + var input_shape = input_shapes[i]; + // skip the condition that input_shape is not `TensorShapeProto`. + AttrValue shape = new AttrValue() + { + Shape = new TensorShapeProto() + }; + shape.Shape = new TensorShapeProto(input_shape); + node_def.Attr["shape"] = shape; + } + if (!fdef.ArgAttr.ContainsKey((uint)i)) + { + fdef.ArgAttr[(uint)i] = new FunctionDef.Types.ArgAttrs(); + } + var arg_attrs = fdef.ArgAttr[(uint)i].Attr; + foreach(var k in arg_attrs.Keys) + { + if(k == "_output_shapes") + { + if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.List) + { + node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].List.Shape[0]); + } + else if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.Shape) + { + node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].Shape); + } + } + else if (k.StartsWith("_")) + { + if (!node_def.Attr.ContainsKey(k)) + { + node_def.Attr[k] = new AttrValue(); + } + node_def.Attr[k] = new AttrValue(arg_attrs[k]); + } + } + + graph_def.Node.Add(node_def); + } + + graph_def.Node.AddRange(fdef.NodeDef); + + Dictionary nested_to_flat_tensor_name = new(); + foreach(var arg_def in fdef.Signature.InputArg) + { + nested_to_flat_tensor_name[arg_def.Name] = $"{arg_def.Name}:0"; + string control_name = "^" + arg_def.Name; + nested_to_flat_tensor_name[control_name] = control_name; + } + + foreach(var node_def in fdef.NodeDef) + { + var graph = default_graph; + // TODO(Rinne): The `Graph` lacks `_functions`, needed to be implemented in the future. + while(graph.OuterGraph is not null) + { + graph = graph.OuterGraph; + } + + var op_def = default_graph.GetOpDef(node_def.Op); + + foreach(var attr in op_def.Attr) + { + if(attr.Type == "func") + { + var fname = node_def.Attr[attr.Name].Func.Name; + if (!is_function(fname)) + { + throw new ValueError($"Function {fname} was not found. Please make sure " + + $"the FunctionDef `fdef` is correct."); + } + } + else if(attr.Type == "list(func)") + { + foreach(var fn in node_def.Attr[attr.Name].List.Func) + { + var fname = fn.Name; + if (!is_function(fname)) + { + throw new ValueError($"Function {fname} was not found. Please make " + + $"sure the FunctionDef `fdef` is correct."); + } + } + } + } + + int flattened_index = 0; + foreach(var arg_def in op_def.OutputArg) + { + var num_args = _get_num_args(arg_def, node_def); + for(int i = 0; i < num_args; i++) + { + var nested_name = $"{node_def.Name}:{arg_def.Name}:{i}"; + var flat_name = $"{node_def.Name}:{flattened_index}"; + nested_to_flat_tensor_name[nested_name] = flat_name; + flattened_index++; + } + } + string control_name = "^" + node_def.Name; + nested_to_flat_tensor_name[control_name] = control_name; + } + + foreach(var node_def in graph_def.Node) + { + for(int i = 0; i < node_def.Input.Count; i++) + { + node_def.Input[i] = nested_to_flat_tensor_name[node_def.Input[i]]; + } + } + + return (graph_def, nested_to_flat_tensor_name); + } + + private static void _set_handle_data(FuncGraph func_graph, FunctionDef fdef) + { + foreach(var (tensor, arg_def) in zip(func_graph.Inputs, fdef.Signature.InputArg).Concat(zip(func_graph.Outputs, fdef.Signature.OutputArg))) + { + if(arg_def.HandleData is not null && arg_def.HandleData.Count > 0) + { + tensor.shape = Shape.Scalar; + + var shape_and_type = arg_def.HandleData[0]; + var handle_data = new HandleData(); + handle_data.IsSet = true; + handle_data.ShapeAndType.Add(new HandleShapeAndType() + { + Shape = shape_and_type.Shape, + Dtype = shape_and_type.Dtype + }); + resource_variable_ops._set_handle_shapes_and_types(tensor, handle_data, true); + } + } + } + + private static long _get_num_args(OpDef.Types.ArgDef arg_def, NodeDef node_def) + { + if (!string.IsNullOrEmpty(arg_def.NumberAttr)) + { + return node_def.Attr[arg_def.NumberAttr].I; + } + else if(!string.IsNullOrEmpty(arg_def.TypeListAttr)) + { + return node_def.Attr[arg_def.TypeListAttr].List.Type.Count; + } + else if(arg_def.TypeAttr is not null || arg_def.Type != DataType.DtInvalid) + { + return 1; + } + else + { + throw new ValueError($"Invalid arg_def:\n\n{arg_def}. Please make sure the " + + $"FunctionDef `fdef` is correct."); + } + } + + public static bool is_function(string fname) + { + if (tf.Context.executing_eagerly()) + { + return tf.Context.has_function(fname); + } + else + { + var graph = ops.get_default_graph(); + while(graph is not null) + { + if (graph.IsFunction(fname)) + { + return true; + } + if(graph.OuterGraph is not null) + { + graph = graph.OuterGraph; + } + else + { + return false; + } + } + } + throw new ValueError("Unexpected behavior happened in runtime, please submit an issue to " + + "https://github.com/SciSharp/TensorFlow.NET/issues"); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/importer.cs b/src/TensorFlowNET.Core/Framework/importer.cs index 5b99c200..a4e6c72e 100644 --- a/src/TensorFlowNET.Core/Framework/importer.cs +++ b/src/TensorFlowNET.Core/Framework/importer.cs @@ -17,6 +17,7 @@ using Google.Protobuf; using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using static Tensorflow.Binding; using static Tensorflow.OpDef.Types; @@ -25,9 +26,14 @@ namespace Tensorflow { public class importer { + public static ITensorOrOperation[] import_graph_def_for_function(GraphDef graph_def, string name = null) + { + return import_graph_def(graph_def, validate_colocation_constraints: false, name: name); + } public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, Dictionary input_map = null, string[] return_elements = null, + bool validate_colocation_constraints = true, string name = null, OpList producer_op_list = null) { @@ -60,7 +66,7 @@ namespace Tensorflow var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); var status = new Status(); - _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); + _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements, validate_colocation_constraints ); // need to create a class ImportGraphDefWithResults with IDisposal results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); status.Check(true); @@ -73,6 +79,42 @@ namespace Tensorflow return _GatherReturnElements(return_elements, graph, results); } + //private static ITensorOrOperation[] _import_graph_def_internal(GraphDef graph_def, Dictionary input_map = null, string[] return_elements = null, + // bool validate_colocation_constraints = true, string name = null, OpList producer_op_list = null) + //{ + // graph_def = _ProcessGraphDefParam(graph_def); + // input_map = _ProcessInputMapParam(input_map); + // return_elements = _ProcessReturnElementsParam(return_elements); + + // if(producer_op_list is not null) + // { + // _RemoveDefaultAttrs(producer_op_list, graph_def); + // } + + // var graph = ops.get_default_graph(); + // string prefix = null; + // tf_with(ops.name_scope(name, "import", input_map.Values), scope => + // { + // if (scope is not null) + // { + // Debug.Assert(scope.scope_name.EndsWith("/")); + // prefix = scope.scope_name[scope.scope_name.Length - 1].ToString(); + // } + // else + // { + // prefix = ""; + // } + + // input_map = _ConvertInputMapValues(name, input_map); + // }); + + // var scope_options = c_api_util.ScopedTFImportGraphDefOptions(); + // var options = scope_options.Options; + // _PopulateTFImportGraphDefOptions(scope_options, prefix, input_map, return_elements, validate_colocation_constraints); + + + //} + private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, Graph graph, TF_ImportGraphDefResults results) @@ -113,15 +155,29 @@ namespace Tensorflow public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, string prefix, Dictionary input_map, - string[] return_elements) + string[] return_elements, + bool validate_colocation_constraints) { c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); - c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); + c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Options, true); foreach (var input in input_map) { - var (src_name, src_index) = _ParseTensorName(input.Key); - c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output()); + var input_src = tf.compat.as_str(input.Key); + var input_dst = input.Value; + if (input_src.StartsWith("^")) + { + var src_name = tf.compat.as_str(input_src.Substring(1)); + var dst_op = input_dst._as_tf_output().oper; + c_api.TF_ImportGraphDefOptionsRemapControlDependency(options.Options, src_name, dst_op); + } + else + { + var (src_name, src_index) = _ParseTensorName(input.Key); + src_name = tf.compat.as_str(src_name); + var dst_output = input_dst._as_tf_output(); + c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Options, src_name, src_index, dst_output); + } } if (return_elements == null) @@ -132,15 +188,16 @@ namespace Tensorflow if (name.Contains(":")) { var (op_name, index) = _ParseTensorName(name); - c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); + op_name = tf.compat.as_str(op_name); + c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Options, op_name, index); } else { - c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); + c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Options, tf.compat.as_str(name)); } } - // c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints); + c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options.Options, validate_colocation_constraints); } private static (string, int) _ParseTensorName(string tensor_name) @@ -173,6 +230,14 @@ namespace Tensorflow return graph_def; } + private static GraphDef _ProcessGraphDefParam(GraphDef graph_def) + { + var old_graph_def = graph_def; + graph_def = new GraphDef(old_graph_def); + + return graph_def; + } + private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) { foreach (var attr_def in op_def.Attr) @@ -240,6 +305,35 @@ namespace Tensorflow } } + private static void _RemoveDefaultAttrs(OpList producer_op_list, GraphDef graph_def) + { + var producer_op_dict = producer_op_list.Op.ToDictionary(x => x.Name, x => x); + + foreach (var node in graph_def.Node) + { + // Remove any default attr values that aren't in op_def. + if (producer_op_dict.ContainsKey(node.Op)) + { + var op_def = op_def_registry.GetOpDef(node.Op); + if(op_def is null) + { + continue; + } + var producer_op_def = producer_op_dict[node.Op]; + foreach (var key in node.Attr.Keys) + { + if (_FindAttrInOpDef(key, op_def) is null) + { + var attr_def = _FindAttrInOpDef(key, producer_op_def); + if (attr_def != null && attr_def.DefaultValue != null && + node.Attr[key] == attr_def.DefaultValue) + node.Attr[key].ClearValue(); + } + } + } + } + } + private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) { return op_def.Attr.FirstOrDefault(x => x.Name == name); diff --git a/src/TensorFlowNET.Core/Framework/versions.cs b/src/TensorFlowNET.Core/Framework/versions.cs new file mode 100644 index 00000000..e91f08a2 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/versions.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Framework +{ + public class versions + { + public static int GRAPH_DEF_VERSION = 1286; + public static int GRAPH_DEF_VERSION_MIN_CONSUMER = 0; + } +} diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index a6720a5f..23c669b3 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -13,6 +13,7 @@ namespace Tensorflow.Functions /// public class ConcreteFunction: Trackable { + protected IEnumerable _captured_inputs; internal FuncGraph func_graph; internal ForwardBackwardCall forward_backward; public Tensor[] Inputs => func_graph.Inputs; @@ -29,11 +30,13 @@ namespace Tensorflow.Functions public ConcreteFunction(string name) { func_graph = new FuncGraph(name); + _captured_inputs = func_graph.external_captures; } public ConcreteFunction(FuncGraph graph, Dictionary attrs = null) { func_graph = graph; + _captured_inputs = func_graph.external_captures; ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); } @@ -53,6 +56,7 @@ namespace Tensorflow.Functions new[] { output }, null); func_graph.Exit(); + _captured_inputs = func_graph.external_captures; } public ConcreteFunction(Func func, TF_DataType dtype) @@ -73,6 +77,7 @@ namespace Tensorflow.Functions new[] { output.variant_tensor }, null); func_graph.Exit(); + _captured_inputs = func_graph.external_captures; } /*public ConcreteFunction(Func func, @@ -174,6 +179,11 @@ namespace Tensorflow.Functions // TODO(Rinne); complete it with `_delayed_rewrite_functions`. } + public void SetExternalCaptures(IEnumerable captures) + { + _captured_inputs = captures; + } + ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) { var functions = new FirstOrderTapeGradientFunctions(func_graph, false); diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs index 056d15f4..45a13632 100644 --- a/src/TensorFlowNET.Core/Functions/Function.cs +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -1,4 +1,5 @@ using System; +using Tensorflow.Functions; using Tensorflow.Train; namespace Tensorflow diff --git a/src/TensorFlowNET.Core/Functions/IGenericFunction.cs b/src/TensorFlowNET.Core/Functions/IGenericFunction.cs new file mode 100644 index 00000000..be6a3b2a --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/IGenericFunction.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Functions +{ + public interface IGenericFunction + { + object[] Apply(params object[] args); + ConcreteFunction get_concrete_function(params object[] args); + } +} diff --git a/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs b/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs new file mode 100644 index 00000000..c39f2402 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs @@ -0,0 +1,88 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow.Functions +{ + public static class function_saved_model_utils + { + /// + /// + /// + /// + /// a list tensors or other objects (such as variables) which + /// contain tensors that were originally captured by the function + public static void restore_captures(ConcreteFunction concrete_function, IEnumerable inputs) + { + var bound_inputs = inputs?.Select(obj => + { + if(obj is Tensor tensor) + { + return get_tensor_from_node(tensor); + } + else if(obj is IVariableV1 variable) + { + return get_tensor_from_node(variable); + } + else + { + throw new TypeError("Encountered an type error, please submit an issue to " + + "https://github.com/SciSharp/TensorFlow.NET/issues"); + } + }); + var bound_variables = inputs.TakeWhile(obj => obj is IVariableV1); + + List captured_inputs_list = new(); + // TODO(Rinne): concrete_function.set_variables(bound_variables) + + + if (bound_inputs is not null) + { + foreach(var (bound_input, internal_capture) in zip(bound_inputs, concrete_function.Inputs.Skip(concrete_function.Inputs.Length - bound_inputs.Count()))) + { + if(hasattr(bound_input, "__tf_experimental_restore_capture__")) + { + throw new NotImplementedException(); + } + else + { + captured_inputs_list.Add(bound_input); + concrete_function.func_graph.replace_capture(bound_input, internal_capture); + if(internal_capture.dtype == dtypes.resource) + { + // skip the check of variable. + handle_data_util.copy_handle_data(bound_input, internal_capture); + } + concrete_function.func_graph.capture(bound_input); + } + } + } + + if(captured_inputs_list.Any(inp => inp is null)) + { + // TODO(Rinne): add warnings. + } + concrete_function.SetExternalCaptures(captured_inputs_list); + } + + public static Tensor get_tensor_from_node(Tensor node) + { + return node; + } + public static Tensor get_tensor_from_node(IVariableV1 node) + { + if (resource_variable_ops.is_resource_variable(node)) + { + return node.Handle; + } + else + { + throw new TypeError("Encountered an type error, please submit an issue to " + + "https://github.com/SciSharp/TensorFlow.NET/issues"); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/custom_gradient.cs b/src/TensorFlowNET.Core/Gradients/custom_gradient.cs new file mode 100644 index 00000000..0a248086 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/custom_gradient.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Gradients +{ + public class custom_gradient + { + public static string generate_name() + { + return $"CustomGradient-{ops.uid()}"; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs index 31cc9c0b..9fe49da2 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs @@ -1,6 +1,7 @@ using MethodBoundaryAspect.Fody.Attributes; using System; using System.Collections.Generic; +using System.IO; using System.Linq; using Tensorflow.Eager; using Tensorflow.Functions; @@ -21,8 +22,9 @@ namespace Tensorflow.Graphs public override void OnEntry(MethodExecutionArgs args) { + File.WriteAllText(@"D:\temp\for_test.txt", "jyfgjyfjhfjhc"); // TODO: func_name can be cache in FullName + Args - func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}"; + func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; if (functions.ContainsKey(func_name)) { diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index 3a209b89..333380c4 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -56,6 +56,11 @@ public class FuncGraph : Graph, IDisposable _handle = handle; } + public void replace_capture(Tensor tensor, Tensor placeholder) + { + _captures[tensor.Id] = (tensor, placeholder); + } + public void ToGraph(Operation[] opers, Tensor[] inputs, Tensor[] outputs, string[] output_names) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 98cad3b2..fccc763e 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -146,6 +146,12 @@ namespace Tensorflow return ops.set_default_graph(this); } + public bool IsFunction(string name) + { + // TODO(Rinne): deal with `_functions`. + throw new NotImplementedException(); + } + private Tensor _as_graph_element(object obj) { if (obj is RefVariable var) diff --git a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs index 859465fc..a7ce6ff5 100644 --- a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs +++ b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs @@ -28,6 +28,8 @@ public sealed class ImportGraphDefOptions _handle = c_api.TF_NewImportGraphDefOptions(); } + public SafeImportGraphDefOptionsHandle Options => _handle; + public void AddReturnOutput(string name, int index) { c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index dc1827d8..6221354f 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -185,6 +185,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints(SafeImportGraphDefOptionsHandle options, bool validate_colocation_constraints); + /// /// Add an output in `graph_def` to be returned via the `return_outputs` output /// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input @@ -246,7 +249,7 @@ namespace Tensorflow /// TF_ImportGraphDefOptions* /// unsigned char [DllImport(TensorFlowLibName)] - public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix); + public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, bool uniquify_prefix); /// /// Fetches the return operations requested via diff --git a/src/TensorFlowNET.Core/Operations/handle_data_util.cs b/src/TensorFlowNET.Core/Operations/handle_data_util.cs new file mode 100644 index 00000000..6d4d8a19 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/handle_data_util.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Eager; + +namespace Tensorflow.Operations +{ + public static class handle_data_util + { + public static void copy_handle_data(Tensor source_t, Tensor target_t) + { + if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) + { + SafeTensorHandle handle_data; + if(source_t is EagerTensor) + { + handle_data = source_t.Handle; + } + else + { + handle_data = ops.get_resource_handle_data(source_t); + } + throw new NotImplementedException(); + //if(handle_data is not null && handle_data.) + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index 6ce7a0b0..2b1d9a84 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -126,7 +126,7 @@ namespace Tensorflow /// /// /// - private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) + internal static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) { if (!graph_mode) return; diff --git a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs index f2597574..32575213 100644 --- a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs @@ -5,6 +5,7 @@ #pragma warning disable 1591, 0612, 3021 #region Designer generated code +using Tensorflow.Framework.Models; using pb = global::Google.Protobuf; using pbc = global::Google.Protobuf.Collections; using pbr = global::Google.Protobuf.Reflection; @@ -2589,9 +2590,17 @@ namespace Tensorflow { } } - #region Nested types - /// Container for nested types declared in the FunctionSpec message type. - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + //public static FunctionSpec from_function_and_signature(string csharp_function, IEnumerable input_signature, bool is_pure = false, object jit_compile = null) + //{ + // // TODO(Rinne): _validate_signature(input_signature) + // // TODO(Rinne): _validate_python_function(python_function, input_signature) + + + //} + + #region Nested types + /// Container for nested types declared in the FunctionSpec message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public static partial class Types { /// /// Whether the function should be compiled by XLA. diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs index d26fe2b5..757e8b7f 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs @@ -1,14 +1,24 @@ -using System; +using Google.Protobuf; +using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; +using System.Text.RegularExpressions; +using Tensorflow.Framework; using Tensorflow.Functions; +using Tensorflow.Gradients; +using Tensorflow.Graphs; using Tensorflow.Util; +using static Tensorflow.Binding; namespace Tensorflow.Training.Saving.SavedModel { public static class function_deserialization { + private static string _INFERENCE_PREFIX = "__inference_"; + private static string _FUNCTION_WRAPPER_NAME_REGEX = $@"^{_INFERENCE_PREFIX}(.*)_\d+$"; /// /// Creates a `Function` from a `SavedFunction`. /// @@ -22,6 +32,338 @@ namespace Tensorflow.Training.Saving.SavedModel return null; } + public static Dictionary load_function_def_library(FunctionDefLibrary library, + SavedObjectGraph saved_object_graph = null, string load_shared_name_suffix = null, object? wrapper_function = null) + { + var library_function_names = library.Function.Select(x => x.Signature.Name).Distinct(); + Dictionary functions = new(); + Dictionary renamed_functions = new(); + + Graph graph; + if (ops.executing_eagerly_outside_functions()) + { + graph = new Graph(); + } + else + { + graph = ops.get_default_graph(); + } + + if(load_shared_name_suffix is null) + { + load_shared_name_suffix = $"_load_{ops.uid()}"; + } + + Dictionary library_gradient_names = new(); + Dictionary new_gradient_op_types = new(); + Dictionary gradients_to_register = new(); + foreach (var gdef in library.RegisteredGradients) + { + if(gdef.RegisteredOpType is not null) + { + var new_op_type = custom_gradient.generate_name(); + var old_op_type = tf.compat.as_bytes(gdef.RegisteredOpType); + + library_gradient_names[old_op_type] = gdef.GradientFunc; + new_gradient_op_types[old_op_type] = new_op_type; + gradients_to_register[gdef.GradientFunc] = new_op_type; + } + } + + Dictionary> function_deps = new(); + foreach(var fdef in library.Function) + { + function_deps[fdef.Signature.Name] = _list_function_deps(fdef, library_function_names, library_gradient_names); + } + + Dictionary loaded_gradients = new(); + int aa = 0; + var temp = _sort_function_defs(library, function_deps); + foreach (var fdef in temp) + { + aa++; + var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); + + if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) + { + // TODO(Rinne): implement it. + //var proto = saved_object_graph.ConcreteFunctions[orig_name]; + //throw new NotImplementedException(); + } + + graph.as_default(); + var func_graph = function_def_lib.function_def_to_graph(fdef, null, null); + graph.Exit(); + + _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients); + + foreach(var dep in function_deps[orig_name]) + { + functions[dep].AddTograph(func_graph); + } + + if (fdef.Attr.ContainsKey("_input_shapes")) + { + fdef.Attr.Remove("_input_shapes"); + } + var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value.S.ToString())); + if(wrapper_function is not null) + { + throw new NotImplementedException(); + } + func.AddTograph(graph); + + functions[orig_name] = func; + renamed_functions[func.Name] = func; + if(func_graph.get_operations().Any(op => op.op.type == "TRTEngineOp")) + { + func.AddTograph(ops.get_default_graph()); + } + + if (gradients_to_register.ContainsKey(orig_name)) + { + var gradient_op_type = gradients_to_register[orig_name]; + loaded_gradients[gradient_op_type] = func; + // TODO(Rinne): deal with gradient registry. + //new RegisteredGradient() { RegisteredOpType = gradient_op_type }. + } + } + return functions; + } + + public static void fix_node_def(NodeDef node_def, IDictionary functions, string shared_name_suffix) + { + if (functions.ContainsKey(node_def.Op)) + { + node_def.Op = functions[node_def.Op].Name; + } + foreach(var attr_value in node_def.Attr.Values) + { + if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func) + { + attr_value.Func.Name = functions[attr_value.Func.Name].Name; + } + else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List) + { + foreach(var fn in attr_value.List.Func) + { + fn.Name = functions[fn.Name].Name; + } + } + } + + if(node_def.Op == "HashTableV2") + { + if(!node_def.Attr.ContainsKey("use_node_name_sharing") || !node_def.Attr["use_node_name_sharing"].B) + { + node_def.Attr["use_node_name_sharing"].B = true; + shared_name_suffix += $"_{ops.uid()}"; + } + } + + var op_def = op_def_registry.GetOpDef(node_def.Op); + if(op_def is not null) + { + var attr = op_def.Attr.Where(x => x.Name == "shared_name").FirstOrDefault(); + if(attr is not null) + { + ByteString shared_name = null; + if(node_def.Attr.ContainsKey("shared_name") && node_def.Attr["shared_name"].S is not null) + { + shared_name = node_def.Attr["shared_name"].S; + } + else if(attr.DefaultValue.S is not null) + { + shared_name = tf.compat.as_bytes(attr.DefaultValue.S); + } + if(shared_name is null) + { + shared_name = tf.compat.as_bytes(node_def.Name); + } + node_def.Attr["shared_name"].S = ByteString.CopyFrom(shared_name.Concat(tf.compat.as_bytes(node_def.Name)).ToArray()); + } + } + } + + private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary renamed_functions, Dictionary loaded_gradients) + { + foreach(var op in func_graph.get_operations()) + { + if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") + { + var function = renamed_functions[tf.compat.as_bytes(op.op.node_def.Attr["f"].Func.Name).ToString()]; + // TODO(Rinne): deal with `op._gradient_function`. + } + string gradient_op_type = null; + try + { + gradient_op_type = op.op.get_attr("_gradient_op_type") as string; + } + catch(Exception e) + { + continue; + } + if (loaded_gradients.ContainsKey(gradient_op_type)) + { + var grad_fn = loaded_gradients[gradient_op_type]; + grad_fn.NumPositionArgs = op.op.inputs.Length; + grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name); + } + } + } + + private static string _fix_fdef_in_place(FunctionDef fdef, IDictionary functions, string shared_name_suffix, + IDictionary new_gradient_op_types) + { + var orig_name = fdef.Signature.Name; + bool contains_unsaved_custom_gradients = false; + + foreach(var node_def in fdef.NodeDef) + { + fix_node_def(node_def, functions, shared_name_suffix); + var op_type = _get_gradient_op_type(node_def); + if(op_type is not null) + { + if (new_gradient_op_types.ContainsKey(op_type)) + { + node_def.Attr["_gradient_op_type"].S = tf.compat.as_bytes(new_gradient_op_types[op_type]); + } + else + { + contains_unsaved_custom_gradients = true; + } + } + } + if (contains_unsaved_custom_gradients) + { + // TODO(Rinne): log warnings. + } + + fdef.Signature.Name = _clean_function_name(fdef.Signature.Name); + return orig_name; + } + + private static string _clean_function_name(string name) + { + var match = Regex.Match(name, _FUNCTION_WRAPPER_NAME_REGEX); + if(match.Success) + { + return match.Groups[1].Value; + } + else + { + return name; + } + } + + /// + /// Return a topologic sort of FunctionDefs in a library. + /// + /// + /// + private static IEnumerable _sort_function_defs(FunctionDefLibrary library, Dictionary> function_deps) + { + Dictionary> edges = new(); + Dictionary in_count = new(); + foreach(var item in function_deps) + { + var fname = item.Key; + var deps = item.Value; + if(deps is null || deps.Count() == 0) + { + in_count[fname] = 0; + continue; + } + foreach(var dep in deps) + { + edges.SetDefault(dep, new List()).Add(fname); + if (in_count.ContainsKey(fname)) + { + in_count[fname]++; + } + else + { + in_count[fname] = 1; + } + } + } + var ready = new Stack(library.Function. + Where(x => in_count[x.Signature.Name] == 0) + .Select(x => x.Signature.Name).ToList()); + List output = new(); + while(ready.Count > 0) + { + var node = ready.Pop(); + output.Add(node); + if (!edges.ContainsKey(node)) + { + continue; + } + foreach(var dest in edges[node]) + { + in_count[dest] -= 1; + if (in_count[dest] == 0) + { + ready.Push(dest); + } + } + } + + if(output.Count != library.Function.Count) + { + var failed_to_resolve = in_count.Keys.Except(output); + throw new ValueError($"There is a cyclic dependency between functions. " + + $"Could not resolve ({string.Join(", ", failed_to_resolve)})."); + } + + var reverse = library.Function.ToDictionary(x => x.Signature.Name, x => x); + return output.Select(x => reverse[x]); + } + + private static IEnumerable _list_function_deps(FunctionDef fdef, IEnumerable library_function_names, IDictionary library_gradient_names) + { + HashSet deps = new HashSet(); + foreach(var node_def in fdef.NodeDef) + { + var grad_op_type = _get_gradient_op_type(node_def); + if (library_function_names.Contains(node_def.Op)) + { + deps.Add(node_def.Op); + } + else if(grad_op_type is not null && library_gradient_names.TryGetValue(grad_op_type, out var gradient_name)) + { + deps.Add(gradient_name); + } + else + { + foreach(var attr_value in node_def.Attr.Values) + { + if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func) + { + deps.Add(attr_value.Func.Name); + } + else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List) + { + foreach(var fn in attr_value.List.Func) + { + deps.Add(fn.Name); + } + } + } + } + } + return deps.AsEnumerable(); + } + + private static ByteString _get_gradient_op_type(NodeDef node_def) + { + if(node_def.Attr.ContainsKey("_gradient_op_type") && node_def.Op != "StatefulPartitionedCall" && node_def.Op != "PartitionedCall") + { + return node_def.Attr["_gradient_op_type"].S; + } + return null; + } + public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, IDictionary concrete_functions) { @@ -30,6 +372,7 @@ namespace Tensorflow.Training.Saving.SavedModel concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); + // TODO(Rinne): set the functiona spec. concrete_function.AddTograph(); return concrete_function; } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index dc9e5ba5..7441e4a4 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -35,6 +35,8 @@ namespace Tensorflow private Dictionary)> _loaded_nodes; private List _nodes; private Dictionary> _node_setters; + private Dictionary _concrete_functions; + private HashSet _restored_concrete_functions; public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir, CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary)> filters) { @@ -44,6 +46,9 @@ namespace Tensorflow _proto = object_graph_proto; _export_dir = export_dir; // TODO: `this._concrete_functions` and `this._restored_concrete_functions` + _concrete_functions = function_deserialization.load_function_def_library( + meta_graph.GraphDef.Library, _proto); + _restored_concrete_functions = new HashSet(); _checkpoint_options = ckpt_options; _save_options = save_options; @@ -464,9 +469,17 @@ namespace Tensorflow } } - private void _setup_function_captures() + private void _setup_function_captures(string concrete_function_name, Dictionary, Trackable> nodes) { - // TODO: implement it with concrete functions. + if (_restored_concrete_functions.Contains(concrete_function_name)) + { + return; + } + _restored_concrete_functions.Add(concrete_function_name); + var concrete_function = _concrete_functions[concrete_function_name]; + var proto = _proto.ConcreteFunctions[concrete_function_name]; + var inputs = proto.BoundInputs.Select(x => nodes[x]); + function_saved_model_utils.restore_captures(concrete_function, inputs); } private void _setup_remaining_functions() @@ -625,7 +638,7 @@ namespace Tensorflow var fn = function_deserialization.recreate_function(proto, null); foreach (var name in proto.ConcreteFunctions) { - _setup_function_captures(); + _setup_function_captures(name, dependencies); } return (fn, setattr); } @@ -633,8 +646,9 @@ namespace Tensorflow private (ConcreteFunction, Action) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, Dictionary, Trackable> dependencies) { - throw new NotImplementedException(); - //var fn = function_deserialization.setup_bare_concrete_function(proto, ) + var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); + _setup_function_captures(proto.ConcreteFunctionName, dependencies); + return (fn, setattr); } // TODO: remove this to a common class. diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/nested_structure_coder.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/nested_structure_coder.cs new file mode 100644 index 00000000..ac8b4cf8 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/nested_structure_coder.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Training.Saving.SavedModel +{ + //public class nested_structure_coder + //{ + // public static List decode_proto(StructuredValue proto) + // { + // return proto s + // } + //} +} diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 48d8b5c5..59081ecf 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -572,6 +572,11 @@ namespace Tensorflow return get_default_graph().building_function; } + public static SafeTensorHandle get_resource_handle_data(Tensor graph_op) + { + throw new NotImplementedException(); + } + public static void dismantle_graph(Graph graph) { diff --git a/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs index e9839850..52e32b7c 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs @@ -8,9 +8,14 @@ namespace Tensorflow.Keras.Saving { public class KerasMetaData { + [JsonProperty("name")] public string Name { get; set; } [JsonProperty("class_name")] public string ClassName { get; set; } + [JsonProperty("trainable")] + public bool Trainable { get; set; } + [JsonProperty("dtype")] + public TF_DataType DType { get; set; } = TF_DataType.DtInvalid; [JsonProperty("is_graph_network")] public bool IsGraphNetwork { get; set; } [JsonProperty("shared_object_id")] @@ -20,5 +25,13 @@ namespace Tensorflow.Keras.Saving public JObject Config { get; set; } [JsonProperty("build_input_shape")] public TensorShapeConfig BuildInputShape { get; set; } + [JsonProperty("batch_input_shape")] + public TensorShapeConfig BatchInputShape { get; set; } + [JsonProperty("activity_regularizer")] + public IRegularizer ActivityRegularizer { get; set; } + [JsonProperty("input_spec")] + public JToken InputSpec { get; set; } + [JsonProperty("stateful")] + public bool? Stateful { get; set; } } } diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index fffc2bac..898eb18f 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Saving { public class KerasObjectLoader { - private static readonly IDictionary PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; + internal static readonly IDictionary PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; private SavedMetadata _metadata; private SavedObjectGraph _proto; private Dictionary _node_paths = new Dictionary(); @@ -311,6 +311,10 @@ namespace Tensorflow.Keras.Saving { (obj, setter) = _revive_custom_object(identifier, metadata); } + if(obj is null) + { + throw new ValueError($"Cannot revive {metadata.Name} from the config or customized object."); + } Debug.Assert(obj is Layer); _maybe_add_serialized_attributes(obj as Layer, metadata); return (obj, setter); @@ -349,8 +353,14 @@ namespace Tensorflow.Keras.Saving private (Trackable, Action) _revive_custom_object(string identifier, KerasMetaData metadata) { - // TODO(Rinne): implement it. - throw new NotImplementedException(); + if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) + { + return RevivedLayer.init_from_metadata(metadata); + } + else + { + throw new NotImplementedException(); + } } Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) @@ -403,9 +413,13 @@ namespace Tensorflow.Keras.Saving var obj = generic_utils.deserialize_keras_object(class_name, config); + if(obj is null) + { + return null; + } obj.Name = metadata.Name; // TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` - + var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); if (!built) diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs new file mode 100644 index 00000000..4dc56130 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs @@ -0,0 +1,62 @@ +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Text.RegularExpressions; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + internal static class ReviveUtils + { + public static T recursively_deserialize_keras_object(JToken config) + { + throw new NotImplementedException(); + if(config is JObject jobject) + { + if (jobject.ContainsKey("class_name")) + { + + } + } + } + + public static void _revive_setter(object layer, object name, object value) + { + Debug.Assert(name is string); + Debug.Assert(layer is Layer); + if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string)) + { + if (value is Trackable trackable) + { + (layer as Layer)._track_trackable(trackable, name as string); + } + (layer as Layer).SerializedAttributes[name] = JToken.FromObject(value); + } + else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) + { + Debug.Assert(value is Trackable); + functional._track_trackable(value as Trackable, name as string); + } + else + { + var properties = layer.GetType().GetProperties(); + foreach (var p in properties) + { + if ((string)name == p.Name) + { + if(p.GetValue(layer) is not null) + { + return; + } + p.SetValue(layer, value); + return; + } + } + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedConfig.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedConfig.cs new file mode 100644 index 00000000..036d517b --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedConfig.cs @@ -0,0 +1,37 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + [JsonConverter(typeof(CustomizedRevivedConfigJsonConverter))] + public class RevivedConfig: IKerasConfig + { + public JObject Config { get; set; } + } + + public class CustomizedRevivedConfigJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(RevivedConfig); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + ((RevivedConfig)value).Config.WriteTo(writer); + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var config = (JObject)serializer.Deserialize(reader, typeof(JObject)); + return new RevivedConfig() { Config = config }; + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs new file mode 100644 index 00000000..cb375c9c --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs @@ -0,0 +1,73 @@ +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Keras.Saving.SavedModel; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public class RevivedLayer: Layer + { + public static (RevivedLayer, Action) init_from_metadata(KerasMetaData metadata) + { + LayerArgs args = new LayerArgs() + { + Name = metadata.Name, + Trainable = metadata.Trainable + }; + if(metadata.DType != TF_DataType.DtInvalid) + { + args.DType = metadata.DType; + } + if(metadata.BatchInputShape is not null) + { + args.BatchInputShape = metadata.BatchInputShape; + } + + RevivedLayer revived_obj = new RevivedLayer(args); + + // TODO(Rinne): implement `expects_training_arg`. + var config = metadata.Config; + if (generic_utils.validate_config(config)) + { + revived_obj._config = new RevivedConfig() { Config = config }; + } + if(metadata.InputSpec is not null) + { + throw new NotImplementedException(); + } + if(metadata.ActivityRegularizer is not null) + { + throw new NotImplementedException(); + } + // TODO(Rinne): `_is_feature_layer` + if(metadata.Stateful is not null) + { + revived_obj.stateful = metadata.Stateful.Value; + } + + return (revived_obj, ReviveUtils._revive_setter); + } + + private RevivedConfig _config = null; + + public RevivedLayer(LayerArgs args): base(args) + { + + } + + public override string ToString() + { + return $"Customized keras layer: {Name}."; + } + + public override IKerasConfig get_config() + { + return _config; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index 03acce0c..1194bebf 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -23,6 +23,7 @@ using System.Data; using System.Diagnostics; using System.Linq; using System.Reflection; +using System.Security.AccessControl; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; @@ -60,6 +61,10 @@ namespace Tensorflow.Keras.Utils public static Layer deserialize_keras_object(string class_name, JToken config) { var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); + if(argType is null) + { + return null; + } var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); @@ -72,6 +77,10 @@ namespace Tensorflow.Keras.Utils public static Layer deserialize_keras_object(string class_name, LayerArgs args) { var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); + if (layer is null) + { + return null; + } Debug.Assert(layer is Layer); return layer as Layer; } diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index f4cbccf5..a24ce727 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -1,10 +1,12 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; using System.Linq; using Tensorflow; using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.UnitTest.Helpers; using Tensorflow.NumPy; using static Tensorflow.Binding; +using static Tensorflow.KerasApi; namespace TensorFlowNET.Keras.UnitTest.SaveModel; @@ -56,4 +58,11 @@ public class SequentialModelLoad model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); } + + [TestMethod] + public void Temp() + { + var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func"); + model.summary(); + } }