| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| @@ -45,6 +46,23 @@ namespace Tensorflow | |||
| { | |||
| return as_text(bytes_or_text, encoding); | |||
| } | |||
| public ByteString as_bytes(ByteString bytes, Encoding encoding = null) | |||
| { | |||
| return bytes; | |||
| } | |||
| public ByteString as_bytes(byte[] bytes, Encoding encoding = null) | |||
| { | |||
| return ByteString.CopyFrom(bytes); | |||
| } | |||
| public ByteString as_bytes(string text, Encoding encoding = null) | |||
| { | |||
| if(encoding is null) | |||
| { | |||
| encoding = Encoding.UTF8; | |||
| } | |||
| return ByteString.CopyFrom(encoding.GetBytes(text)); | |||
| } | |||
| } | |||
| public bool executing_eagerly() | |||
| @@ -54,6 +54,6 @@ namespace Tensorflow | |||
| Dictionary<string, Tensor> input_map = null, | |||
| string[] return_elements = null, | |||
| string name = null, | |||
| OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list); | |||
| OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list); | |||
| } | |||
| } | |||
| @@ -156,6 +156,12 @@ namespace Tensorflow.Contexts | |||
| return has_graph_arg; | |||
| } | |||
| public bool has_function(string name) | |||
| { | |||
| ensure_initialized(); | |||
| return c_api.TFE_ContextHasFunction(_handle, name); | |||
| } | |||
| public void restore_mode() | |||
| { | |||
| context_switches.Pop(); | |||
| @@ -0,0 +1,288 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Security.Cryptography; | |||
| using System.Text; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| namespace Tensorflow.Framework | |||
| { | |||
| public class function_def_lib | |||
| { | |||
| // TODO(Rinne): process signatures and structured outputs. | |||
| public static FuncGraph function_def_to_graph(FunctionDef fdef, object? structured_input_signature, | |||
| object? structured_outputs, List<TensorShapeProto> input_shapes = null) | |||
| { | |||
| var func_graph = new FuncGraph(fdef.Signature.Name); | |||
| if(input_shapes is null) | |||
| { | |||
| if(fdef.Attr.TryGetValue("_input_shapes", out var input_shapes_attr)) | |||
| { | |||
| var raw_input_shapes = input_shapes_attr.List.Shape; | |||
| input_shapes = new List<TensorShapeProto>(); | |||
| foreach(var (input_shape, arg_def) in raw_input_shapes.Zip(fdef.Signature.InputArg, (x, y) => (x, y))) | |||
| { | |||
| if(arg_def.Type == DataType.DtResource && arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||
| { | |||
| input_shapes.Add(null); | |||
| } | |||
| else | |||
| { | |||
| input_shapes.Add(input_shape); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| var (graph_def, nested_to_flat_tensor_name) = function_def_to_graph_def(fdef, input_shapes); | |||
| func_graph.as_default(); | |||
| importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | |||
| var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | |||
| func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
| var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | |||
| func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
| // TODO(Rinne): func_graph.ControlOutputs | |||
| _set_handle_data(func_graph, fdef); | |||
| foreach(var node in graph_def.Node) | |||
| { | |||
| if(node.Attr.TryGetValue("_output_shapes", out var output_shapes)) | |||
| { | |||
| var op = func_graph.get_operation_by_name(node.Name); | |||
| foreach(var (output_index, shape) in enumerate(output_shapes.List.Shape.Take(op.outputs.Length))) | |||
| { | |||
| op.outputs[output_index].shape = new Shape(shape); | |||
| } | |||
| } | |||
| } | |||
| Dictionary<long, string> output_names = new(); | |||
| foreach(var (ret_arg_def, tensor_name) in zip(fdef.Signature.OutputArg, output_tensor_names)) | |||
| { | |||
| output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name; | |||
| } | |||
| // TODO(Rinne): func_graph._output_names = output_names | |||
| func_graph.Exit(); | |||
| return func_graph; | |||
| } | |||
| public static (GraphDef, Dictionary<string, string>) function_def_to_graph_def(FunctionDef fdef, List<TensorShapeProto> input_shapes) | |||
| { | |||
| var graph_def = new GraphDef() | |||
| { | |||
| Versions = new VersionDef() | |||
| { | |||
| Producer = versions.GRAPH_DEF_VERSION, | |||
| MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER | |||
| } | |||
| }; | |||
| var default_graph = ops.get_default_graph(); | |||
| if(input_shapes is not null && input_shapes.Count > 0 && input_shapes.Count != fdef.Signature.InputArg.Count) | |||
| { | |||
| throw new ValueError($"Length of `input_shapes` must match the number " + | |||
| $"of `input_arg`s in `fdef`. Got {input_shapes.Count} `input_shapes` and " + | |||
| $"{fdef.Signature.InputArg.Count} `input_arg`s."); | |||
| } | |||
| foreach(var (i, arg_def) in enumerate(fdef.Signature.InputArg)) | |||
| { | |||
| NodeDef node_def = new(); | |||
| node_def.Name = arg_def.Name; | |||
| node_def.Op = "Placeholder"; | |||
| node_def.Attr["dtype"] = new AttrValue() | |||
| { | |||
| Type = arg_def.Type | |||
| }; | |||
| if(input_shapes is not null && input_shapes.Count > 0 && input_shapes[i] is not null) | |||
| { | |||
| var input_shape = input_shapes[i]; | |||
| // skip the condition that input_shape is not `TensorShapeProto`. | |||
| AttrValue shape = new AttrValue() | |||
| { | |||
| Shape = new TensorShapeProto() | |||
| }; | |||
| shape.Shape = new TensorShapeProto(input_shape); | |||
| node_def.Attr["shape"] = shape; | |||
| } | |||
| if (!fdef.ArgAttr.ContainsKey((uint)i)) | |||
| { | |||
| fdef.ArgAttr[(uint)i] = new FunctionDef.Types.ArgAttrs(); | |||
| } | |||
| var arg_attrs = fdef.ArgAttr[(uint)i].Attr; | |||
| foreach(var k in arg_attrs.Keys) | |||
| { | |||
| if(k == "_output_shapes") | |||
| { | |||
| if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.List) | |||
| { | |||
| node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].List.Shape[0]); | |||
| } | |||
| else if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.Shape) | |||
| { | |||
| node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].Shape); | |||
| } | |||
| } | |||
| else if (k.StartsWith("_")) | |||
| { | |||
| if (!node_def.Attr.ContainsKey(k)) | |||
| { | |||
| node_def.Attr[k] = new AttrValue(); | |||
| } | |||
| node_def.Attr[k] = new AttrValue(arg_attrs[k]); | |||
| } | |||
| } | |||
| graph_def.Node.Add(node_def); | |||
| } | |||
| graph_def.Node.AddRange(fdef.NodeDef); | |||
| Dictionary<string, string> nested_to_flat_tensor_name = new(); | |||
| foreach(var arg_def in fdef.Signature.InputArg) | |||
| { | |||
| nested_to_flat_tensor_name[arg_def.Name] = $"{arg_def.Name}:0"; | |||
| string control_name = "^" + arg_def.Name; | |||
| nested_to_flat_tensor_name[control_name] = control_name; | |||
| } | |||
| foreach(var node_def in fdef.NodeDef) | |||
| { | |||
| var graph = default_graph; | |||
| // TODO(Rinne): The `Graph` lacks `_functions`, needed to be implemented in the future. | |||
| while(graph.OuterGraph is not null) | |||
| { | |||
| graph = graph.OuterGraph; | |||
| } | |||
| var op_def = default_graph.GetOpDef(node_def.Op); | |||
| foreach(var attr in op_def.Attr) | |||
| { | |||
| if(attr.Type == "func") | |||
| { | |||
| var fname = node_def.Attr[attr.Name].Func.Name; | |||
| if (!is_function(fname)) | |||
| { | |||
| throw new ValueError($"Function {fname} was not found. Please make sure " + | |||
| $"the FunctionDef `fdef` is correct."); | |||
| } | |||
| } | |||
| else if(attr.Type == "list(func)") | |||
| { | |||
| foreach(var fn in node_def.Attr[attr.Name].List.Func) | |||
| { | |||
| var fname = fn.Name; | |||
| if (!is_function(fname)) | |||
| { | |||
| throw new ValueError($"Function {fname} was not found. Please make " + | |||
| $"sure the FunctionDef `fdef` is correct."); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| int flattened_index = 0; | |||
| foreach(var arg_def in op_def.OutputArg) | |||
| { | |||
| var num_args = _get_num_args(arg_def, node_def); | |||
| for(int i = 0; i < num_args; i++) | |||
| { | |||
| var nested_name = $"{node_def.Name}:{arg_def.Name}:{i}"; | |||
| var flat_name = $"{node_def.Name}:{flattened_index}"; | |||
| nested_to_flat_tensor_name[nested_name] = flat_name; | |||
| flattened_index++; | |||
| } | |||
| } | |||
| string control_name = "^" + node_def.Name; | |||
| nested_to_flat_tensor_name[control_name] = control_name; | |||
| } | |||
| foreach(var node_def in graph_def.Node) | |||
| { | |||
| for(int i = 0; i < node_def.Input.Count; i++) | |||
| { | |||
| node_def.Input[i] = nested_to_flat_tensor_name[node_def.Input[i]]; | |||
| } | |||
| } | |||
| return (graph_def, nested_to_flat_tensor_name); | |||
| } | |||
| private static void _set_handle_data(FuncGraph func_graph, FunctionDef fdef) | |||
| { | |||
| foreach(var (tensor, arg_def) in zip(func_graph.Inputs, fdef.Signature.InputArg).Concat(zip(func_graph.Outputs, fdef.Signature.OutputArg))) | |||
| { | |||
| if(arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||
| { | |||
| tensor.shape = Shape.Scalar; | |||
| var shape_and_type = arg_def.HandleData[0]; | |||
| var handle_data = new HandleData(); | |||
| handle_data.IsSet = true; | |||
| handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||
| { | |||
| Shape = shape_and_type.Shape, | |||
| Dtype = shape_and_type.Dtype | |||
| }); | |||
| resource_variable_ops._set_handle_shapes_and_types(tensor, handle_data, true); | |||
| } | |||
| } | |||
| } | |||
| private static long _get_num_args(OpDef.Types.ArgDef arg_def, NodeDef node_def) | |||
| { | |||
| if (!string.IsNullOrEmpty(arg_def.NumberAttr)) | |||
| { | |||
| return node_def.Attr[arg_def.NumberAttr].I; | |||
| } | |||
| else if(!string.IsNullOrEmpty(arg_def.TypeListAttr)) | |||
| { | |||
| return node_def.Attr[arg_def.TypeListAttr].List.Type.Count; | |||
| } | |||
| else if(arg_def.TypeAttr is not null || arg_def.Type != DataType.DtInvalid) | |||
| { | |||
| return 1; | |||
| } | |||
| else | |||
| { | |||
| throw new ValueError($"Invalid arg_def:\n\n{arg_def}. Please make sure the " + | |||
| $"FunctionDef `fdef` is correct."); | |||
| } | |||
| } | |||
| public static bool is_function(string fname) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| return tf.Context.has_function(fname); | |||
| } | |||
| else | |||
| { | |||
| var graph = ops.get_default_graph(); | |||
| while(graph is not null) | |||
| { | |||
| if (graph.IsFunction(fname)) | |||
| { | |||
| return true; | |||
| } | |||
| if(graph.OuterGraph is not null) | |||
| { | |||
| graph = graph.OuterGraph; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| throw new ValueError("Unexpected behavior happened in runtime, please submit an issue to " + | |||
| "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| } | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.OpDef.Types; | |||
| @@ -25,9 +26,14 @@ namespace Tensorflow | |||
| { | |||
| public class importer | |||
| { | |||
| public static ITensorOrOperation[] import_graph_def_for_function(GraphDef graph_def, string name = null) | |||
| { | |||
| return import_graph_def(graph_def, validate_colocation_constraints: false, name: name); | |||
| } | |||
| public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, | |||
| Dictionary<string, Tensor> input_map = null, | |||
| string[] return_elements = null, | |||
| bool validate_colocation_constraints = true, | |||
| string name = null, | |||
| OpList producer_op_list = null) | |||
| { | |||
| @@ -60,7 +66,7 @@ namespace Tensorflow | |||
| var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||
| var status = new Status(); | |||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements, validate_colocation_constraints ); | |||
| // need to create a class ImportGraphDefWithResults with IDisposal | |||
| results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); | |||
| status.Check(true); | |||
| @@ -73,6 +79,42 @@ namespace Tensorflow | |||
| return _GatherReturnElements(return_elements, graph, results); | |||
| } | |||
| //private static ITensorOrOperation[] _import_graph_def_internal(GraphDef graph_def, Dictionary<string, Tensor> input_map = null, string[] return_elements = null, | |||
| // bool validate_colocation_constraints = true, string name = null, OpList producer_op_list = null) | |||
| //{ | |||
| // graph_def = _ProcessGraphDefParam(graph_def); | |||
| // input_map = _ProcessInputMapParam(input_map); | |||
| // return_elements = _ProcessReturnElementsParam(return_elements); | |||
| // if(producer_op_list is not null) | |||
| // { | |||
| // _RemoveDefaultAttrs(producer_op_list, graph_def); | |||
| // } | |||
| // var graph = ops.get_default_graph(); | |||
| // string prefix = null; | |||
| // tf_with(ops.name_scope(name, "import", input_map.Values), scope => | |||
| // { | |||
| // if (scope is not null) | |||
| // { | |||
| // Debug.Assert(scope.scope_name.EndsWith("/")); | |||
| // prefix = scope.scope_name[scope.scope_name.Length - 1].ToString(); | |||
| // } | |||
| // else | |||
| // { | |||
| // prefix = ""; | |||
| // } | |||
| // input_map = _ConvertInputMapValues(name, input_map); | |||
| // }); | |||
| // var scope_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||
| // var options = scope_options.Options; | |||
| // _PopulateTFImportGraphDefOptions(scope_options, prefix, input_map, return_elements, validate_colocation_constraints); | |||
| //} | |||
| private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, | |||
| Graph graph, | |||
| TF_ImportGraphDefResults results) | |||
| @@ -113,15 +155,29 @@ namespace Tensorflow | |||
| public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, | |||
| string prefix, | |||
| Dictionary<string, Tensor> input_map, | |||
| string[] return_elements) | |||
| string[] return_elements, | |||
| bool validate_colocation_constraints) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | |||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Options, true); | |||
| foreach (var input in input_map) | |||
| { | |||
| var (src_name, src_index) = _ParseTensorName(input.Key); | |||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output()); | |||
| var input_src = tf.compat.as_str(input.Key); | |||
| var input_dst = input.Value; | |||
| if (input_src.StartsWith("^")) | |||
| { | |||
| var src_name = tf.compat.as_str(input_src.Substring(1)); | |||
| var dst_op = input_dst._as_tf_output().oper; | |||
| c_api.TF_ImportGraphDefOptionsRemapControlDependency(options.Options, src_name, dst_op); | |||
| } | |||
| else | |||
| { | |||
| var (src_name, src_index) = _ParseTensorName(input.Key); | |||
| src_name = tf.compat.as_str(src_name); | |||
| var dst_output = input_dst._as_tf_output(); | |||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Options, src_name, src_index, dst_output); | |||
| } | |||
| } | |||
| if (return_elements == null) | |||
| @@ -132,15 +188,16 @@ namespace Tensorflow | |||
| if (name.Contains(":")) | |||
| { | |||
| var (op_name, index) = _ParseTensorName(name); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||
| op_name = tf.compat.as_str(op_name); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Options, op_name, index); | |||
| } | |||
| else | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Options, tf.compat.as_str(name)); | |||
| } | |||
| } | |||
| // c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints); | |||
| c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options.Options, validate_colocation_constraints); | |||
| } | |||
| private static (string, int) _ParseTensorName(string tensor_name) | |||
| @@ -173,6 +230,14 @@ namespace Tensorflow | |||
| return graph_def; | |||
| } | |||
| private static GraphDef _ProcessGraphDefParam(GraphDef graph_def) | |||
| { | |||
| var old_graph_def = graph_def; | |||
| graph_def = new GraphDef(old_graph_def); | |||
| return graph_def; | |||
| } | |||
| private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) | |||
| { | |||
| foreach (var attr_def in op_def.Attr) | |||
| @@ -240,6 +305,35 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private static void _RemoveDefaultAttrs(OpList producer_op_list, GraphDef graph_def) | |||
| { | |||
| var producer_op_dict = producer_op_list.Op.ToDictionary(x => x.Name, x => x); | |||
| foreach (var node in graph_def.Node) | |||
| { | |||
| // Remove any default attr values that aren't in op_def. | |||
| if (producer_op_dict.ContainsKey(node.Op)) | |||
| { | |||
| var op_def = op_def_registry.GetOpDef(node.Op); | |||
| if(op_def is null) | |||
| { | |||
| continue; | |||
| } | |||
| var producer_op_def = producer_op_dict[node.Op]; | |||
| foreach (var key in node.Attr.Keys) | |||
| { | |||
| if (_FindAttrInOpDef(key, op_def) is null) | |||
| { | |||
| var attr_def = _FindAttrInOpDef(key, producer_op_def); | |||
| if (attr_def != null && attr_def.DefaultValue != null && | |||
| node.Attr[key] == attr_def.DefaultValue) | |||
| node.Attr[key].ClearValue(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) | |||
| { | |||
| return op_def.Attr.FirstOrDefault(x => x.Name == name); | |||
| @@ -0,0 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Framework | |||
| { | |||
| public class versions | |||
| { | |||
| public static int GRAPH_DEF_VERSION = 1286; | |||
| public static int GRAPH_DEF_VERSION_MIN_CONSUMER = 0; | |||
| } | |||
| } | |||
| @@ -13,6 +13,7 @@ namespace Tensorflow.Functions | |||
| /// </summary> | |||
| public class ConcreteFunction: Trackable | |||
| { | |||
| protected IEnumerable<Tensor> _captured_inputs; | |||
| internal FuncGraph func_graph; | |||
| internal ForwardBackwardCall forward_backward; | |||
| public Tensor[] Inputs => func_graph.Inputs; | |||
| @@ -29,11 +30,13 @@ namespace Tensorflow.Functions | |||
| public ConcreteFunction(string name) | |||
| { | |||
| func_graph = new FuncGraph(name); | |||
| _captured_inputs = func_graph.external_captures; | |||
| } | |||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | |||
| { | |||
| func_graph = graph; | |||
| _captured_inputs = func_graph.external_captures; | |||
| ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||
| } | |||
| @@ -53,6 +56,7 @@ namespace Tensorflow.Functions | |||
| new[] { output }, | |||
| null); | |||
| func_graph.Exit(); | |||
| _captured_inputs = func_graph.external_captures; | |||
| } | |||
| public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | |||
| @@ -73,6 +77,7 @@ namespace Tensorflow.Functions | |||
| new[] { output.variant_tensor }, | |||
| null); | |||
| func_graph.Exit(); | |||
| _captured_inputs = func_graph.external_captures; | |||
| } | |||
| /*public ConcreteFunction(Func<Tensors, Tensors> func, | |||
| @@ -174,6 +179,11 @@ namespace Tensorflow.Functions | |||
| // TODO(Rinne); complete it with `_delayed_rewrite_functions`. | |||
| } | |||
| public void SetExternalCaptures(IEnumerable<Tensor> captures) | |||
| { | |||
| _captured_inputs = captures; | |||
| } | |||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
| { | |||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow | |||
| @@ -0,0 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| public interface IGenericFunction | |||
| { | |||
| object[] Apply(params object[] args); | |||
| ConcreteFunction get_concrete_function(params object[] args); | |||
| } | |||
| } | |||
| @@ -0,0 +1,88 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| public static class function_saved_model_utils | |||
| { | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="concrete_function"></param> | |||
| /// <param name="inputs">a list tensors or other objects (such as variables) which | |||
| /// contain tensors that were originally captured by the function</param> | |||
| public static void restore_captures(ConcreteFunction concrete_function, IEnumerable<object> inputs) | |||
| { | |||
| var bound_inputs = inputs?.Select(obj => | |||
| { | |||
| if(obj is Tensor tensor) | |||
| { | |||
| return get_tensor_from_node(tensor); | |||
| } | |||
| else if(obj is IVariableV1 variable) | |||
| { | |||
| return get_tensor_from_node(variable); | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("Encountered an type error, please submit an issue to " + | |||
| "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| }); | |||
| var bound_variables = inputs.TakeWhile(obj => obj is IVariableV1); | |||
| List<Tensor> captured_inputs_list = new(); | |||
| // TODO(Rinne): concrete_function.set_variables(bound_variables) | |||
| if (bound_inputs is not null) | |||
| { | |||
| foreach(var (bound_input, internal_capture) in zip(bound_inputs, concrete_function.Inputs.Skip(concrete_function.Inputs.Length - bound_inputs.Count()))) | |||
| { | |||
| if(hasattr(bound_input, "__tf_experimental_restore_capture__")) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| else | |||
| { | |||
| captured_inputs_list.Add(bound_input); | |||
| concrete_function.func_graph.replace_capture(bound_input, internal_capture); | |||
| if(internal_capture.dtype == dtypes.resource) | |||
| { | |||
| // skip the check of variable. | |||
| handle_data_util.copy_handle_data(bound_input, internal_capture); | |||
| } | |||
| concrete_function.func_graph.capture(bound_input); | |||
| } | |||
| } | |||
| } | |||
| if(captured_inputs_list.Any(inp => inp is null)) | |||
| { | |||
| // TODO(Rinne): add warnings. | |||
| } | |||
| concrete_function.SetExternalCaptures(captured_inputs_list); | |||
| } | |||
| public static Tensor get_tensor_from_node(Tensor node) | |||
| { | |||
| return node; | |||
| } | |||
| public static Tensor get_tensor_from_node(IVariableV1 node) | |||
| { | |||
| if (resource_variable_ops.is_resource_variable(node)) | |||
| { | |||
| return node.Handle; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("Encountered an type error, please submit an issue to " + | |||
| "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Gradients | |||
| { | |||
| public class custom_gradient | |||
| { | |||
| public static string generate_name() | |||
| { | |||
| return $"CustomGradient-{ops.uid()}"; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using MethodBoundaryAspect.Fody.Attributes; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Functions; | |||
| @@ -21,8 +22,9 @@ namespace Tensorflow.Graphs | |||
| public override void OnEntry(MethodExecutionArgs args) | |||
| { | |||
| File.WriteAllText(@"D:\temp\for_test.txt", "jyfgjyfjhfjhc"); | |||
| // TODO: func_name can be cache in FullName + Args | |||
| func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}"; | |||
| func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; | |||
| if (functions.ContainsKey(func_name)) | |||
| { | |||
| @@ -56,6 +56,11 @@ public class FuncGraph : Graph, IDisposable | |||
| _handle = handle; | |||
| } | |||
| public void replace_capture(Tensor tensor, Tensor placeholder) | |||
| { | |||
| _captures[tensor.Id] = (tensor, placeholder); | |||
| } | |||
| public void ToGraph(Operation[] opers, | |||
| Tensor[] inputs, Tensor[] outputs, | |||
| string[] output_names) | |||
| @@ -146,6 +146,12 @@ namespace Tensorflow | |||
| return ops.set_default_graph(this); | |||
| } | |||
| public bool IsFunction(string name) | |||
| { | |||
| // TODO(Rinne): deal with `_functions`. | |||
| throw new NotImplementedException(); | |||
| } | |||
| private Tensor _as_graph_element(object obj) | |||
| { | |||
| if (obj is RefVariable var) | |||
| @@ -28,6 +28,8 @@ public sealed class ImportGraphDefOptions | |||
| _handle = c_api.TF_NewImportGraphDefOptions(); | |||
| } | |||
| public SafeImportGraphDefOptionsHandle Options => _handle; | |||
| public void AddReturnOutput(string name, int index) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||
| @@ -185,6 +185,9 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints(SafeImportGraphDefOptionsHandle options, bool validate_colocation_constraints); | |||
| /// <summary> | |||
| /// Add an output in `graph_def` to be returned via the `return_outputs` output | |||
| /// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input | |||
| @@ -246,7 +249,7 @@ namespace Tensorflow | |||
| /// <param name="ops">TF_ImportGraphDefOptions*</param> | |||
| /// <param name="uniquify_prefix">unsigned char</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix); | |||
| public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, bool uniquify_prefix); | |||
| /// <summary> | |||
| /// Fetches the return operations requested via | |||
| @@ -0,0 +1,28 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| public static class handle_data_util | |||
| { | |||
| public static void copy_handle_data(Tensor source_t, Tensor target_t) | |||
| { | |||
| if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) | |||
| { | |||
| SafeTensorHandle handle_data; | |||
| if(source_t is EagerTensor) | |||
| { | |||
| handle_data = source_t.Handle; | |||
| } | |||
| else | |||
| { | |||
| handle_data = ops.get_resource_handle_data(source_t); | |||
| } | |||
| throw new NotImplementedException(); | |||
| //if(handle_data is not null && handle_data.) | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -126,7 +126,7 @@ namespace Tensorflow | |||
| /// <param name="handle"></param> | |||
| /// <param name="handle_data"></param> | |||
| /// <param name="graph_mode"></param> | |||
| private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
| internal static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
| { | |||
| if (!graph_mode) | |||
| return; | |||
| @@ -5,6 +5,7 @@ | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #region Designer generated code | |||
| using Tensorflow.Framework.Models; | |||
| using pb = global::Google.Protobuf; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using pbr = global::Google.Protobuf.Reflection; | |||
| @@ -2589,9 +2590,17 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| #region Nested types | |||
| /// <summary>Container for nested types declared in the FunctionSpec message type.</summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| //public static FunctionSpec from_function_and_signature(string csharp_function, IEnumerable<TensorSpec> input_signature, bool is_pure = false, object jit_compile = null) | |||
| //{ | |||
| // // TODO(Rinne): _validate_signature(input_signature) | |||
| // // TODO(Rinne): _validate_python_function(python_function, input_signature) | |||
| //} | |||
| #region Nested types | |||
| /// <summary>Container for nested types declared in the FunctionSpec message type.</summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static partial class Types { | |||
| /// <summary> | |||
| /// Whether the function should be compiled by XLA. | |||
| @@ -1,14 +1,24 @@ | |||
| using System; | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| using System.Text.RegularExpressions; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| public static class function_deserialization | |||
| { | |||
| private static string _INFERENCE_PREFIX = "__inference_"; | |||
| private static string _FUNCTION_WRAPPER_NAME_REGEX = $@"^{_INFERENCE_PREFIX}(.*)_\d+$"; | |||
| /// <summary> | |||
| /// Creates a `Function` from a `SavedFunction`. | |||
| /// </summary> | |||
| @@ -22,6 +32,338 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| return null; | |||
| } | |||
| public static Dictionary<string, ConcreteFunction> load_function_def_library(FunctionDefLibrary library, | |||
| SavedObjectGraph saved_object_graph = null, string load_shared_name_suffix = null, object? wrapper_function = null) | |||
| { | |||
| var library_function_names = library.Function.Select(x => x.Signature.Name).Distinct(); | |||
| Dictionary<string, ConcreteFunction> functions = new(); | |||
| Dictionary<string, ConcreteFunction> renamed_functions = new(); | |||
| Graph graph; | |||
| if (ops.executing_eagerly_outside_functions()) | |||
| { | |||
| graph = new Graph(); | |||
| } | |||
| else | |||
| { | |||
| graph = ops.get_default_graph(); | |||
| } | |||
| if(load_shared_name_suffix is null) | |||
| { | |||
| load_shared_name_suffix = $"_load_{ops.uid()}"; | |||
| } | |||
| Dictionary<ByteString, string> library_gradient_names = new(); | |||
| Dictionary<ByteString, string> new_gradient_op_types = new(); | |||
| Dictionary<string, string> gradients_to_register = new(); | |||
| foreach (var gdef in library.RegisteredGradients) | |||
| { | |||
| if(gdef.RegisteredOpType is not null) | |||
| { | |||
| var new_op_type = custom_gradient.generate_name(); | |||
| var old_op_type = tf.compat.as_bytes(gdef.RegisteredOpType); | |||
| library_gradient_names[old_op_type] = gdef.GradientFunc; | |||
| new_gradient_op_types[old_op_type] = new_op_type; | |||
| gradients_to_register[gdef.GradientFunc] = new_op_type; | |||
| } | |||
| } | |||
| Dictionary<string, IEnumerable<string>> function_deps = new(); | |||
| foreach(var fdef in library.Function) | |||
| { | |||
| function_deps[fdef.Signature.Name] = _list_function_deps(fdef, library_function_names, library_gradient_names); | |||
| } | |||
| Dictionary<string, ConcreteFunction> loaded_gradients = new(); | |||
| int aa = 0; | |||
| var temp = _sort_function_defs(library, function_deps); | |||
| foreach (var fdef in temp) | |||
| { | |||
| aa++; | |||
| var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); | |||
| if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| //var proto = saved_object_graph.ConcreteFunctions[orig_name]; | |||
| //throw new NotImplementedException(); | |||
| } | |||
| graph.as_default(); | |||
| var func_graph = function_def_lib.function_def_to_graph(fdef, null, null); | |||
| graph.Exit(); | |||
| _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients); | |||
| foreach(var dep in function_deps[orig_name]) | |||
| { | |||
| functions[dep].AddTograph(func_graph); | |||
| } | |||
| if (fdef.Attr.ContainsKey("_input_shapes")) | |||
| { | |||
| fdef.Attr.Remove("_input_shapes"); | |||
| } | |||
| var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value.S.ToString())); | |||
| if(wrapper_function is not null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| func.AddTograph(graph); | |||
| functions[orig_name] = func; | |||
| renamed_functions[func.Name] = func; | |||
| if(func_graph.get_operations().Any(op => op.op.type == "TRTEngineOp")) | |||
| { | |||
| func.AddTograph(ops.get_default_graph()); | |||
| } | |||
| if (gradients_to_register.ContainsKey(orig_name)) | |||
| { | |||
| var gradient_op_type = gradients_to_register[orig_name]; | |||
| loaded_gradients[gradient_op_type] = func; | |||
| // TODO(Rinne): deal with gradient registry. | |||
| //new RegisteredGradient() { RegisteredOpType = gradient_op_type }. | |||
| } | |||
| } | |||
| return functions; | |||
| } | |||
| public static void fix_node_def(NodeDef node_def, IDictionary<string, ConcreteFunction> functions, string shared_name_suffix) | |||
| { | |||
| if (functions.ContainsKey(node_def.Op)) | |||
| { | |||
| node_def.Op = functions[node_def.Op].Name; | |||
| } | |||
| foreach(var attr_value in node_def.Attr.Values) | |||
| { | |||
| if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func) | |||
| { | |||
| attr_value.Func.Name = functions[attr_value.Func.Name].Name; | |||
| } | |||
| else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List) | |||
| { | |||
| foreach(var fn in attr_value.List.Func) | |||
| { | |||
| fn.Name = functions[fn.Name].Name; | |||
| } | |||
| } | |||
| } | |||
| if(node_def.Op == "HashTableV2") | |||
| { | |||
| if(!node_def.Attr.ContainsKey("use_node_name_sharing") || !node_def.Attr["use_node_name_sharing"].B) | |||
| { | |||
| node_def.Attr["use_node_name_sharing"].B = true; | |||
| shared_name_suffix += $"_{ops.uid()}"; | |||
| } | |||
| } | |||
| var op_def = op_def_registry.GetOpDef(node_def.Op); | |||
| if(op_def is not null) | |||
| { | |||
| var attr = op_def.Attr.Where(x => x.Name == "shared_name").FirstOrDefault(); | |||
| if(attr is not null) | |||
| { | |||
| ByteString shared_name = null; | |||
| if(node_def.Attr.ContainsKey("shared_name") && node_def.Attr["shared_name"].S is not null) | |||
| { | |||
| shared_name = node_def.Attr["shared_name"].S; | |||
| } | |||
| else if(attr.DefaultValue.S is not null) | |||
| { | |||
| shared_name = tf.compat.as_bytes(attr.DefaultValue.S); | |||
| } | |||
| if(shared_name is null) | |||
| { | |||
| shared_name = tf.compat.as_bytes(node_def.Name); | |||
| } | |||
| node_def.Attr["shared_name"].S = ByteString.CopyFrom(shared_name.Concat(tf.compat.as_bytes(node_def.Name)).ToArray()); | |||
| } | |||
| } | |||
| } | |||
| private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients) | |||
| { | |||
| foreach(var op in func_graph.get_operations()) | |||
| { | |||
| if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | |||
| { | |||
| var function = renamed_functions[tf.compat.as_bytes(op.op.node_def.Attr["f"].Func.Name).ToString()]; | |||
| // TODO(Rinne): deal with `op._gradient_function`. | |||
| } | |||
| string gradient_op_type = null; | |||
| try | |||
| { | |||
| gradient_op_type = op.op.get_attr("_gradient_op_type") as string; | |||
| } | |||
| catch(Exception e) | |||
| { | |||
| continue; | |||
| } | |||
| if (loaded_gradients.ContainsKey(gradient_op_type)) | |||
| { | |||
| var grad_fn = loaded_gradients[gradient_op_type]; | |||
| grad_fn.NumPositionArgs = op.op.inputs.Length; | |||
| grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name); | |||
| } | |||
| } | |||
| } | |||
| private static string _fix_fdef_in_place(FunctionDef fdef, IDictionary<string, ConcreteFunction> functions, string shared_name_suffix, | |||
| IDictionary<ByteString, string> new_gradient_op_types) | |||
| { | |||
| var orig_name = fdef.Signature.Name; | |||
| bool contains_unsaved_custom_gradients = false; | |||
| foreach(var node_def in fdef.NodeDef) | |||
| { | |||
| fix_node_def(node_def, functions, shared_name_suffix); | |||
| var op_type = _get_gradient_op_type(node_def); | |||
| if(op_type is not null) | |||
| { | |||
| if (new_gradient_op_types.ContainsKey(op_type)) | |||
| { | |||
| node_def.Attr["_gradient_op_type"].S = tf.compat.as_bytes(new_gradient_op_types[op_type]); | |||
| } | |||
| else | |||
| { | |||
| contains_unsaved_custom_gradients = true; | |||
| } | |||
| } | |||
| } | |||
| if (contains_unsaved_custom_gradients) | |||
| { | |||
| // TODO(Rinne): log warnings. | |||
| } | |||
| fdef.Signature.Name = _clean_function_name(fdef.Signature.Name); | |||
| return orig_name; | |||
| } | |||
| private static string _clean_function_name(string name) | |||
| { | |||
| var match = Regex.Match(name, _FUNCTION_WRAPPER_NAME_REGEX); | |||
| if(match.Success) | |||
| { | |||
| return match.Groups[1].Value; | |||
| } | |||
| else | |||
| { | |||
| return name; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Return a topologic sort of FunctionDefs in a library. | |||
| /// </summary> | |||
| /// <param name="library"></param> | |||
| /// <param name="function_deps"></param> | |||
| private static IEnumerable<FunctionDef> _sort_function_defs(FunctionDefLibrary library, Dictionary<string, IEnumerable<string>> function_deps) | |||
| { | |||
| Dictionary<string, IList<string>> edges = new(); | |||
| Dictionary<string, int> in_count = new(); | |||
| foreach(var item in function_deps) | |||
| { | |||
| var fname = item.Key; | |||
| var deps = item.Value; | |||
| if(deps is null || deps.Count() == 0) | |||
| { | |||
| in_count[fname] = 0; | |||
| continue; | |||
| } | |||
| foreach(var dep in deps) | |||
| { | |||
| edges.SetDefault(dep, new List<string>()).Add(fname); | |||
| if (in_count.ContainsKey(fname)) | |||
| { | |||
| in_count[fname]++; | |||
| } | |||
| else | |||
| { | |||
| in_count[fname] = 1; | |||
| } | |||
| } | |||
| } | |||
| var ready = new Stack<string>(library.Function. | |||
| Where(x => in_count[x.Signature.Name] == 0) | |||
| .Select(x => x.Signature.Name).ToList()); | |||
| List<string> output = new(); | |||
| while(ready.Count > 0) | |||
| { | |||
| var node = ready.Pop(); | |||
| output.Add(node); | |||
| if (!edges.ContainsKey(node)) | |||
| { | |||
| continue; | |||
| } | |||
| foreach(var dest in edges[node]) | |||
| { | |||
| in_count[dest] -= 1; | |||
| if (in_count[dest] == 0) | |||
| { | |||
| ready.Push(dest); | |||
| } | |||
| } | |||
| } | |||
| if(output.Count != library.Function.Count) | |||
| { | |||
| var failed_to_resolve = in_count.Keys.Except(output); | |||
| throw new ValueError($"There is a cyclic dependency between functions. " + | |||
| $"Could not resolve ({string.Join(", ", failed_to_resolve)})."); | |||
| } | |||
| var reverse = library.Function.ToDictionary(x => x.Signature.Name, x => x); | |||
| return output.Select(x => reverse[x]); | |||
| } | |||
| private static IEnumerable<string> _list_function_deps(FunctionDef fdef, IEnumerable<string> library_function_names, IDictionary<ByteString, string> library_gradient_names) | |||
| { | |||
| HashSet<string> deps = new HashSet<string>(); | |||
| foreach(var node_def in fdef.NodeDef) | |||
| { | |||
| var grad_op_type = _get_gradient_op_type(node_def); | |||
| if (library_function_names.Contains(node_def.Op)) | |||
| { | |||
| deps.Add(node_def.Op); | |||
| } | |||
| else if(grad_op_type is not null && library_gradient_names.TryGetValue(grad_op_type, out var gradient_name)) | |||
| { | |||
| deps.Add(gradient_name); | |||
| } | |||
| else | |||
| { | |||
| foreach(var attr_value in node_def.Attr.Values) | |||
| { | |||
| if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func) | |||
| { | |||
| deps.Add(attr_value.Func.Name); | |||
| } | |||
| else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List) | |||
| { | |||
| foreach(var fn in attr_value.List.Func) | |||
| { | |||
| deps.Add(fn.Name); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return deps.AsEnumerable(); | |||
| } | |||
| private static ByteString _get_gradient_op_type(NodeDef node_def) | |||
| { | |||
| if(node_def.Attr.ContainsKey("_gradient_op_type") && node_def.Op != "StatefulPartitionedCall" && node_def.Op != "PartitionedCall") | |||
| { | |||
| return node_def.Attr["_gradient_op_type"].S; | |||
| } | |||
| return null; | |||
| } | |||
| public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, | |||
| IDictionary<string, ConcreteFunction> concrete_functions) | |||
| { | |||
| @@ -30,6 +372,7 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; | |||
| var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); | |||
| // TODO(Rinne): set the functiona spec. | |||
| concrete_function.AddTograph(); | |||
| return concrete_function; | |||
| } | |||
| @@ -35,6 +35,8 @@ namespace Tensorflow | |||
| private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; | |||
| private List<Trackable> _nodes; | |||
| private Dictionary<int, Action<object, object, object>> _node_setters; | |||
| private Dictionary<string, ConcreteFunction> _concrete_functions; | |||
| private HashSet<string> _restored_concrete_functions; | |||
| public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir, | |||
| CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters) | |||
| { | |||
| @@ -44,6 +46,9 @@ namespace Tensorflow | |||
| _proto = object_graph_proto; | |||
| _export_dir = export_dir; | |||
| // TODO: `this._concrete_functions` and `this._restored_concrete_functions` | |||
| _concrete_functions = function_deserialization.load_function_def_library( | |||
| meta_graph.GraphDef.Library, _proto); | |||
| _restored_concrete_functions = new HashSet<string>(); | |||
| _checkpoint_options = ckpt_options; | |||
| _save_options = save_options; | |||
| @@ -464,9 +469,17 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private void _setup_function_captures() | |||
| private void _setup_function_captures(string concrete_function_name, Dictionary<Maybe<string, int>, Trackable> nodes) | |||
| { | |||
| // TODO: implement it with concrete functions. | |||
| if (_restored_concrete_functions.Contains(concrete_function_name)) | |||
| { | |||
| return; | |||
| } | |||
| _restored_concrete_functions.Add(concrete_function_name); | |||
| var concrete_function = _concrete_functions[concrete_function_name]; | |||
| var proto = _proto.ConcreteFunctions[concrete_function_name]; | |||
| var inputs = proto.BoundInputs.Select(x => nodes[x]); | |||
| function_saved_model_utils.restore_captures(concrete_function, inputs); | |||
| } | |||
| private void _setup_remaining_functions() | |||
| @@ -625,7 +638,7 @@ namespace Tensorflow | |||
| var fn = function_deserialization.recreate_function(proto, null); | |||
| foreach (var name in proto.ConcreteFunctions) | |||
| { | |||
| _setup_function_captures(); | |||
| _setup_function_captures(name, dependencies); | |||
| } | |||
| return (fn, setattr); | |||
| } | |||
| @@ -633,8 +646,9 @@ namespace Tensorflow | |||
| private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | |||
| Dictionary<Maybe<string, int>, Trackable> dependencies) | |||
| { | |||
| throw new NotImplementedException(); | |||
| //var fn = function_deserialization.setup_bare_concrete_function(proto, ) | |||
| var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); | |||
| _setup_function_captures(proto.ConcreteFunctionName, dependencies); | |||
| return (fn, setattr); | |||
| } | |||
| // TODO: remove this to a common class. | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| //public class nested_structure_coder | |||
| //{ | |||
| // public static List<object> decode_proto(StructuredValue proto) | |||
| // { | |||
| // return proto s | |||
| // } | |||
| //} | |||
| } | |||
| @@ -572,6 +572,11 @@ namespace Tensorflow | |||
| return get_default_graph().building_function; | |||
| } | |||
| public static SafeTensorHandle get_resource_handle_data(Tensor graph_op) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public static void dismantle_graph(Graph graph) | |||
| { | |||
| @@ -8,9 +8,14 @@ namespace Tensorflow.Keras.Saving | |||
| { | |||
| public class KerasMetaData | |||
| { | |||
| [JsonProperty("name")] | |||
| public string Name { get; set; } | |||
| [JsonProperty("class_name")] | |||
| public string ClassName { get; set; } | |||
| [JsonProperty("trainable")] | |||
| public bool Trainable { get; set; } | |||
| [JsonProperty("dtype")] | |||
| public TF_DataType DType { get; set; } = TF_DataType.DtInvalid; | |||
| [JsonProperty("is_graph_network")] | |||
| public bool IsGraphNetwork { get; set; } | |||
| [JsonProperty("shared_object_id")] | |||
| @@ -20,5 +25,13 @@ namespace Tensorflow.Keras.Saving | |||
| public JObject Config { get; set; } | |||
| [JsonProperty("build_input_shape")] | |||
| public TensorShapeConfig BuildInputShape { get; set; } | |||
| [JsonProperty("batch_input_shape")] | |||
| public TensorShapeConfig BatchInputShape { get; set; } | |||
| [JsonProperty("activity_regularizer")] | |||
| public IRegularizer ActivityRegularizer { get; set; } | |||
| [JsonProperty("input_spec")] | |||
| public JToken InputSpec { get; set; } | |||
| [JsonProperty("stateful")] | |||
| public bool? Stateful { get; set; } | |||
| } | |||
| } | |||
| @@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Saving | |||
| { | |||
| public class KerasObjectLoader | |||
| { | |||
| private static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||
| internal static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||
| private SavedMetadata _metadata; | |||
| private SavedObjectGraph _proto; | |||
| private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||
| @@ -311,6 +311,10 @@ namespace Tensorflow.Keras.Saving | |||
| { | |||
| (obj, setter) = _revive_custom_object(identifier, metadata); | |||
| } | |||
| if(obj is null) | |||
| { | |||
| throw new ValueError($"Cannot revive {metadata.Name} from the config or customized object."); | |||
| } | |||
| Debug.Assert(obj is Layer); | |||
| _maybe_add_serialized_attributes(obj as Layer, metadata); | |||
| return (obj, setter); | |||
| @@ -349,8 +353,14 @@ namespace Tensorflow.Keras.Saving | |||
| private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| throw new NotImplementedException(); | |||
| if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) | |||
| { | |||
| return RevivedLayer.init_from_metadata(metadata); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) | |||
| @@ -403,9 +413,13 @@ namespace Tensorflow.Keras.Saving | |||
| var obj = generic_utils.deserialize_keras_object(class_name, config); | |||
| if(obj is null) | |||
| { | |||
| return null; | |||
| } | |||
| obj.Name = metadata.Name; | |||
| // TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` | |||
| var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); | |||
| if (!built) | |||
| @@ -0,0 +1,62 @@ | |||
| using Newtonsoft.Json.Linq; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| using System.Text.RegularExpressions; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| internal static class ReviveUtils | |||
| { | |||
| public static T recursively_deserialize_keras_object<T>(JToken config) | |||
| { | |||
| throw new NotImplementedException(); | |||
| if(config is JObject jobject) | |||
| { | |||
| if (jobject.ContainsKey("class_name")) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| public static void _revive_setter(object layer, object name, object value) | |||
| { | |||
| Debug.Assert(name is string); | |||
| Debug.Assert(layer is Layer); | |||
| if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | |||
| { | |||
| if (value is Trackable trackable) | |||
| { | |||
| (layer as Layer)._track_trackable(trackable, name as string); | |||
| } | |||
| (layer as Layer).SerializedAttributes[name] = JToken.FromObject(value); | |||
| } | |||
| else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||
| { | |||
| Debug.Assert(value is Trackable); | |||
| functional._track_trackable(value as Trackable, name as string); | |||
| } | |||
| else | |||
| { | |||
| var properties = layer.GetType().GetProperties(); | |||
| foreach (var p in properties) | |||
| { | |||
| if ((string)name == p.Name) | |||
| { | |||
| if(p.GetValue(layer) is not null) | |||
| { | |||
| return; | |||
| } | |||
| p.SetValue(layer, value); | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,37 @@ | |||
| using Newtonsoft.Json; | |||
| using Newtonsoft.Json.Linq; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| [JsonConverter(typeof(CustomizedRevivedConfigJsonConverter))] | |||
| public class RevivedConfig: IKerasConfig | |||
| { | |||
| public JObject Config { get; set; } | |||
| } | |||
| public class CustomizedRevivedConfigJsonConverter : JsonConverter | |||
| { | |||
| public override bool CanConvert(Type objectType) | |||
| { | |||
| return objectType == typeof(RevivedConfig); | |||
| } | |||
| public override bool CanRead => true; | |||
| public override bool CanWrite => true; | |||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||
| { | |||
| ((RevivedConfig)value).Config.WriteTo(writer); | |||
| } | |||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
| { | |||
| var config = (JObject)serializer.Deserialize(reader, typeof(JObject)); | |||
| return new RevivedConfig() { Config = config }; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,73 @@ | |||
| using Newtonsoft.Json.Linq; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| public class RevivedLayer: Layer | |||
| { | |||
| public static (RevivedLayer, Action<object, object, object>) init_from_metadata(KerasMetaData metadata) | |||
| { | |||
| LayerArgs args = new LayerArgs() | |||
| { | |||
| Name = metadata.Name, | |||
| Trainable = metadata.Trainable | |||
| }; | |||
| if(metadata.DType != TF_DataType.DtInvalid) | |||
| { | |||
| args.DType = metadata.DType; | |||
| } | |||
| if(metadata.BatchInputShape is not null) | |||
| { | |||
| args.BatchInputShape = metadata.BatchInputShape; | |||
| } | |||
| RevivedLayer revived_obj = new RevivedLayer(args); | |||
| // TODO(Rinne): implement `expects_training_arg`. | |||
| var config = metadata.Config; | |||
| if (generic_utils.validate_config(config)) | |||
| { | |||
| revived_obj._config = new RevivedConfig() { Config = config }; | |||
| } | |||
| if(metadata.InputSpec is not null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| if(metadata.ActivityRegularizer is not null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| // TODO(Rinne): `_is_feature_layer` | |||
| if(metadata.Stateful is not null) | |||
| { | |||
| revived_obj.stateful = metadata.Stateful.Value; | |||
| } | |||
| return (revived_obj, ReviveUtils._revive_setter); | |||
| } | |||
| private RevivedConfig _config = null; | |||
| public RevivedLayer(LayerArgs args): base(args) | |||
| { | |||
| } | |||
| public override string ToString() | |||
| { | |||
| return $"Customized keras layer: {Name}."; | |||
| } | |||
| public override IKerasConfig get_config() | |||
| { | |||
| return _config; | |||
| } | |||
| } | |||
| } | |||
| @@ -23,6 +23,7 @@ using System.Data; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Reflection; | |||
| using System.Security.AccessControl; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| @@ -60,6 +61,10 @@ namespace Tensorflow.Keras.Utils | |||
| public static Layer deserialize_keras_object(string class_name, JToken config) | |||
| { | |||
| var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); | |||
| if(argType is null) | |||
| { | |||
| return null; | |||
| } | |||
| var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) | |||
| .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); | |||
| var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); | |||
| @@ -72,6 +77,10 @@ namespace Tensorflow.Keras.Utils | |||
| public static Layer deserialize_keras_object(string class_name, LayerArgs args) | |||
| { | |||
| var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); | |||
| if (layer is null) | |||
| { | |||
| return null; | |||
| } | |||
| Debug.Assert(layer is Layer); | |||
| return layer as Layer; | |||
| } | |||
| @@ -1,10 +1,12 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using Tensorflow.Keras.UnitTest.Helpers; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||
| @@ -56,4 +58,11 @@ public class SequentialModelLoad | |||
| model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | |||
| } | |||
| [TestMethod] | |||
| public void Temp() | |||
| { | |||
| var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func"); | |||
| model.summary(); | |||
| } | |||
| } | |||