| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Google.Protobuf; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -45,6 +46,23 @@ namespace Tensorflow | |||||
| { | { | ||||
| return as_text(bytes_or_text, encoding); | return as_text(bytes_or_text, encoding); | ||||
| } | } | ||||
| public ByteString as_bytes(ByteString bytes, Encoding encoding = null) | |||||
| { | |||||
| return bytes; | |||||
| } | |||||
| public ByteString as_bytes(byte[] bytes, Encoding encoding = null) | |||||
| { | |||||
| return ByteString.CopyFrom(bytes); | |||||
| } | |||||
| public ByteString as_bytes(string text, Encoding encoding = null) | |||||
| { | |||||
| if(encoding is null) | |||||
| { | |||||
| encoding = Encoding.UTF8; | |||||
| } | |||||
| return ByteString.CopyFrom(encoding.GetBytes(text)); | |||||
| } | |||||
| } | } | ||||
| public bool executing_eagerly() | public bool executing_eagerly() | ||||
| @@ -54,6 +54,6 @@ namespace Tensorflow | |||||
| Dictionary<string, Tensor> input_map = null, | Dictionary<string, Tensor> input_map = null, | ||||
| string[] return_elements = null, | string[] return_elements = null, | ||||
| string name = null, | string name = null, | ||||
| OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list); | |||||
| OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list); | |||||
| } | } | ||||
| } | } | ||||
| @@ -156,6 +156,12 @@ namespace Tensorflow.Contexts | |||||
| return has_graph_arg; | return has_graph_arg; | ||||
| } | } | ||||
| public bool has_function(string name) | |||||
| { | |||||
| ensure_initialized(); | |||||
| return c_api.TFE_ContextHasFunction(_handle, name); | |||||
| } | |||||
| public void restore_mode() | public void restore_mode() | ||||
| { | { | ||||
| context_switches.Pop(); | context_switches.Pop(); | ||||
| @@ -0,0 +1,288 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.Security.Cryptography; | |||||
| using System.Text; | |||||
| using Tensorflow.Graphs; | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||||
| namespace Tensorflow.Framework | |||||
| { | |||||
| public class function_def_lib | |||||
| { | |||||
| // TODO(Rinne): process signatures and structured outputs. | |||||
| public static FuncGraph function_def_to_graph(FunctionDef fdef, object? structured_input_signature, | |||||
| object? structured_outputs, List<TensorShapeProto> input_shapes = null) | |||||
| { | |||||
| var func_graph = new FuncGraph(fdef.Signature.Name); | |||||
| if(input_shapes is null) | |||||
| { | |||||
| if(fdef.Attr.TryGetValue("_input_shapes", out var input_shapes_attr)) | |||||
| { | |||||
| var raw_input_shapes = input_shapes_attr.List.Shape; | |||||
| input_shapes = new List<TensorShapeProto>(); | |||||
| foreach(var (input_shape, arg_def) in raw_input_shapes.Zip(fdef.Signature.InputArg, (x, y) => (x, y))) | |||||
| { | |||||
| if(arg_def.Type == DataType.DtResource && arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||||
| { | |||||
| input_shapes.Add(null); | |||||
| } | |||||
| else | |||||
| { | |||||
| input_shapes.Add(input_shape); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| var (graph_def, nested_to_flat_tensor_name) = function_def_to_graph_def(fdef, input_shapes); | |||||
| func_graph.as_default(); | |||||
| importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | |||||
| var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | |||||
| func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||||
| var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | |||||
| func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||||
| // TODO(Rinne): func_graph.ControlOutputs | |||||
| _set_handle_data(func_graph, fdef); | |||||
| foreach(var node in graph_def.Node) | |||||
| { | |||||
| if(node.Attr.TryGetValue("_output_shapes", out var output_shapes)) | |||||
| { | |||||
| var op = func_graph.get_operation_by_name(node.Name); | |||||
| foreach(var (output_index, shape) in enumerate(output_shapes.List.Shape.Take(op.outputs.Length))) | |||||
| { | |||||
| op.outputs[output_index].shape = new Shape(shape); | |||||
| } | |||||
| } | |||||
| } | |||||
| Dictionary<long, string> output_names = new(); | |||||
| foreach(var (ret_arg_def, tensor_name) in zip(fdef.Signature.OutputArg, output_tensor_names)) | |||||
| { | |||||
| output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name; | |||||
| } | |||||
| // TODO(Rinne): func_graph._output_names = output_names | |||||
| func_graph.Exit(); | |||||
| return func_graph; | |||||
| } | |||||
| public static (GraphDef, Dictionary<string, string>) function_def_to_graph_def(FunctionDef fdef, List<TensorShapeProto> input_shapes) | |||||
| { | |||||
| var graph_def = new GraphDef() | |||||
| { | |||||
| Versions = new VersionDef() | |||||
| { | |||||
| Producer = versions.GRAPH_DEF_VERSION, | |||||
| MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER | |||||
| } | |||||
| }; | |||||
| var default_graph = ops.get_default_graph(); | |||||
| if(input_shapes is not null && input_shapes.Count > 0 && input_shapes.Count != fdef.Signature.InputArg.Count) | |||||
| { | |||||
| throw new ValueError($"Length of `input_shapes` must match the number " + | |||||
| $"of `input_arg`s in `fdef`. Got {input_shapes.Count} `input_shapes` and " + | |||||
| $"{fdef.Signature.InputArg.Count} `input_arg`s."); | |||||
| } | |||||
| foreach(var (i, arg_def) in enumerate(fdef.Signature.InputArg)) | |||||
| { | |||||
| NodeDef node_def = new(); | |||||
| node_def.Name = arg_def.Name; | |||||
| node_def.Op = "Placeholder"; | |||||
| node_def.Attr["dtype"] = new AttrValue() | |||||
| { | |||||
| Type = arg_def.Type | |||||
| }; | |||||
| if(input_shapes is not null && input_shapes.Count > 0 && input_shapes[i] is not null) | |||||
| { | |||||
| var input_shape = input_shapes[i]; | |||||
| // skip the condition that input_shape is not `TensorShapeProto`. | |||||
| AttrValue shape = new AttrValue() | |||||
| { | |||||
| Shape = new TensorShapeProto() | |||||
| }; | |||||
| shape.Shape = new TensorShapeProto(input_shape); | |||||
| node_def.Attr["shape"] = shape; | |||||
| } | |||||
| if (!fdef.ArgAttr.ContainsKey((uint)i)) | |||||
| { | |||||
| fdef.ArgAttr[(uint)i] = new FunctionDef.Types.ArgAttrs(); | |||||
| } | |||||
| var arg_attrs = fdef.ArgAttr[(uint)i].Attr; | |||||
| foreach(var k in arg_attrs.Keys) | |||||
| { | |||||
| if(k == "_output_shapes") | |||||
| { | |||||
| if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.List) | |||||
| { | |||||
| node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].List.Shape[0]); | |||||
| } | |||||
| else if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.Shape) | |||||
| { | |||||
| node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].Shape); | |||||
| } | |||||
| } | |||||
| else if (k.StartsWith("_")) | |||||
| { | |||||
| if (!node_def.Attr.ContainsKey(k)) | |||||
| { | |||||
| node_def.Attr[k] = new AttrValue(); | |||||
| } | |||||
| node_def.Attr[k] = new AttrValue(arg_attrs[k]); | |||||
| } | |||||
| } | |||||
| graph_def.Node.Add(node_def); | |||||
| } | |||||
| graph_def.Node.AddRange(fdef.NodeDef); | |||||
| Dictionary<string, string> nested_to_flat_tensor_name = new(); | |||||
| foreach(var arg_def in fdef.Signature.InputArg) | |||||
| { | |||||
| nested_to_flat_tensor_name[arg_def.Name] = $"{arg_def.Name}:0"; | |||||
| string control_name = "^" + arg_def.Name; | |||||
| nested_to_flat_tensor_name[control_name] = control_name; | |||||
| } | |||||
| foreach(var node_def in fdef.NodeDef) | |||||
| { | |||||
| var graph = default_graph; | |||||
| // TODO(Rinne): The `Graph` lacks `_functions`, needed to be implemented in the future. | |||||
| while(graph.OuterGraph is not null) | |||||
| { | |||||
| graph = graph.OuterGraph; | |||||
| } | |||||
| var op_def = default_graph.GetOpDef(node_def.Op); | |||||
| foreach(var attr in op_def.Attr) | |||||
| { | |||||
| if(attr.Type == "func") | |||||
| { | |||||
| var fname = node_def.Attr[attr.Name].Func.Name; | |||||
| if (!is_function(fname)) | |||||
| { | |||||
| throw new ValueError($"Function {fname} was not found. Please make sure " + | |||||
| $"the FunctionDef `fdef` is correct."); | |||||
| } | |||||
| } | |||||
| else if(attr.Type == "list(func)") | |||||
| { | |||||
| foreach(var fn in node_def.Attr[attr.Name].List.Func) | |||||
| { | |||||
| var fname = fn.Name; | |||||
| if (!is_function(fname)) | |||||
| { | |||||
| throw new ValueError($"Function {fname} was not found. Please make " + | |||||
| $"sure the FunctionDef `fdef` is correct."); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| int flattened_index = 0; | |||||
| foreach(var arg_def in op_def.OutputArg) | |||||
| { | |||||
| var num_args = _get_num_args(arg_def, node_def); | |||||
| for(int i = 0; i < num_args; i++) | |||||
| { | |||||
| var nested_name = $"{node_def.Name}:{arg_def.Name}:{i}"; | |||||
| var flat_name = $"{node_def.Name}:{flattened_index}"; | |||||
| nested_to_flat_tensor_name[nested_name] = flat_name; | |||||
| flattened_index++; | |||||
| } | |||||
| } | |||||
| string control_name = "^" + node_def.Name; | |||||
| nested_to_flat_tensor_name[control_name] = control_name; | |||||
| } | |||||
| foreach(var node_def in graph_def.Node) | |||||
| { | |||||
| for(int i = 0; i < node_def.Input.Count; i++) | |||||
| { | |||||
| node_def.Input[i] = nested_to_flat_tensor_name[node_def.Input[i]]; | |||||
| } | |||||
| } | |||||
| return (graph_def, nested_to_flat_tensor_name); | |||||
| } | |||||
| private static void _set_handle_data(FuncGraph func_graph, FunctionDef fdef) | |||||
| { | |||||
| foreach(var (tensor, arg_def) in zip(func_graph.Inputs, fdef.Signature.InputArg).Concat(zip(func_graph.Outputs, fdef.Signature.OutputArg))) | |||||
| { | |||||
| if(arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||||
| { | |||||
| tensor.shape = Shape.Scalar; | |||||
| var shape_and_type = arg_def.HandleData[0]; | |||||
| var handle_data = new HandleData(); | |||||
| handle_data.IsSet = true; | |||||
| handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||||
| { | |||||
| Shape = shape_and_type.Shape, | |||||
| Dtype = shape_and_type.Dtype | |||||
| }); | |||||
| resource_variable_ops._set_handle_shapes_and_types(tensor, handle_data, true); | |||||
| } | |||||
| } | |||||
| } | |||||
| private static long _get_num_args(OpDef.Types.ArgDef arg_def, NodeDef node_def) | |||||
| { | |||||
| if (!string.IsNullOrEmpty(arg_def.NumberAttr)) | |||||
| { | |||||
| return node_def.Attr[arg_def.NumberAttr].I; | |||||
| } | |||||
| else if(!string.IsNullOrEmpty(arg_def.TypeListAttr)) | |||||
| { | |||||
| return node_def.Attr[arg_def.TypeListAttr].List.Type.Count; | |||||
| } | |||||
| else if(arg_def.TypeAttr is not null || arg_def.Type != DataType.DtInvalid) | |||||
| { | |||||
| return 1; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new ValueError($"Invalid arg_def:\n\n{arg_def}. Please make sure the " + | |||||
| $"FunctionDef `fdef` is correct."); | |||||
| } | |||||
| } | |||||
| public static bool is_function(string fname) | |||||
| { | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| return tf.Context.has_function(fname); | |||||
| } | |||||
| else | |||||
| { | |||||
| var graph = ops.get_default_graph(); | |||||
| while(graph is not null) | |||||
| { | |||||
| if (graph.IsFunction(fname)) | |||||
| { | |||||
| return true; | |||||
| } | |||||
| if(graph.OuterGraph is not null) | |||||
| { | |||||
| graph = graph.OuterGraph; | |||||
| } | |||||
| else | |||||
| { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| throw new ValueError("Unexpected behavior happened in runtime, please submit an issue to " + | |||||
| "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -17,6 +17,7 @@ | |||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.OpDef.Types; | using static Tensorflow.OpDef.Types; | ||||
| @@ -25,9 +26,14 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class importer | public class importer | ||||
| { | { | ||||
| public static ITensorOrOperation[] import_graph_def_for_function(GraphDef graph_def, string name = null) | |||||
| { | |||||
| return import_graph_def(graph_def, validate_colocation_constraints: false, name: name); | |||||
| } | |||||
| public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, | public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, | ||||
| Dictionary<string, Tensor> input_map = null, | Dictionary<string, Tensor> input_map = null, | ||||
| string[] return_elements = null, | string[] return_elements = null, | ||||
| bool validate_colocation_constraints = true, | |||||
| string name = null, | string name = null, | ||||
| OpList producer_op_list = null) | OpList producer_op_list = null) | ||||
| { | { | ||||
| @@ -60,7 +66,7 @@ namespace Tensorflow | |||||
| var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | ||||
| var status = new Status(); | var status = new Status(); | ||||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements, validate_colocation_constraints ); | |||||
| // need to create a class ImportGraphDefWithResults with IDisposal | // need to create a class ImportGraphDefWithResults with IDisposal | ||||
| results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); | results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); | ||||
| status.Check(true); | status.Check(true); | ||||
| @@ -73,6 +79,42 @@ namespace Tensorflow | |||||
| return _GatherReturnElements(return_elements, graph, results); | return _GatherReturnElements(return_elements, graph, results); | ||||
| } | } | ||||
| //private static ITensorOrOperation[] _import_graph_def_internal(GraphDef graph_def, Dictionary<string, Tensor> input_map = null, string[] return_elements = null, | |||||
| // bool validate_colocation_constraints = true, string name = null, OpList producer_op_list = null) | |||||
| //{ | |||||
| // graph_def = _ProcessGraphDefParam(graph_def); | |||||
| // input_map = _ProcessInputMapParam(input_map); | |||||
| // return_elements = _ProcessReturnElementsParam(return_elements); | |||||
| // if(producer_op_list is not null) | |||||
| // { | |||||
| // _RemoveDefaultAttrs(producer_op_list, graph_def); | |||||
| // } | |||||
| // var graph = ops.get_default_graph(); | |||||
| // string prefix = null; | |||||
| // tf_with(ops.name_scope(name, "import", input_map.Values), scope => | |||||
| // { | |||||
| // if (scope is not null) | |||||
| // { | |||||
| // Debug.Assert(scope.scope_name.EndsWith("/")); | |||||
| // prefix = scope.scope_name[scope.scope_name.Length - 1].ToString(); | |||||
| // } | |||||
| // else | |||||
| // { | |||||
| // prefix = ""; | |||||
| // } | |||||
| // input_map = _ConvertInputMapValues(name, input_map); | |||||
| // }); | |||||
| // var scope_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||||
| // var options = scope_options.Options; | |||||
| // _PopulateTFImportGraphDefOptions(scope_options, prefix, input_map, return_elements, validate_colocation_constraints); | |||||
| //} | |||||
| private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, | private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, | ||||
| Graph graph, | Graph graph, | ||||
| TF_ImportGraphDefResults results) | TF_ImportGraphDefResults results) | ||||
| @@ -113,15 +155,29 @@ namespace Tensorflow | |||||
| public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, | public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, | ||||
| string prefix, | string prefix, | ||||
| Dictionary<string, Tensor> input_map, | Dictionary<string, Tensor> input_map, | ||||
| string[] return_elements) | |||||
| string[] return_elements, | |||||
| bool validate_colocation_constraints) | |||||
| { | { | ||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | ||||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Options, true); | |||||
| foreach (var input in input_map) | foreach (var input in input_map) | ||||
| { | { | ||||
| var (src_name, src_index) = _ParseTensorName(input.Key); | |||||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output()); | |||||
| var input_src = tf.compat.as_str(input.Key); | |||||
| var input_dst = input.Value; | |||||
| if (input_src.StartsWith("^")) | |||||
| { | |||||
| var src_name = tf.compat.as_str(input_src.Substring(1)); | |||||
| var dst_op = input_dst._as_tf_output().oper; | |||||
| c_api.TF_ImportGraphDefOptionsRemapControlDependency(options.Options, src_name, dst_op); | |||||
| } | |||||
| else | |||||
| { | |||||
| var (src_name, src_index) = _ParseTensorName(input.Key); | |||||
| src_name = tf.compat.as_str(src_name); | |||||
| var dst_output = input_dst._as_tf_output(); | |||||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Options, src_name, src_index, dst_output); | |||||
| } | |||||
| } | } | ||||
| if (return_elements == null) | if (return_elements == null) | ||||
| @@ -132,15 +188,16 @@ namespace Tensorflow | |||||
| if (name.Contains(":")) | if (name.Contains(":")) | ||||
| { | { | ||||
| var (op_name, index) = _ParseTensorName(name); | var (op_name, index) = _ParseTensorName(name); | ||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||||
| op_name = tf.compat.as_str(op_name); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Options, op_name, index); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Options, tf.compat.as_str(name)); | |||||
| } | } | ||||
| } | } | ||||
| // c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints); | |||||
| c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options.Options, validate_colocation_constraints); | |||||
| } | } | ||||
| private static (string, int) _ParseTensorName(string tensor_name) | private static (string, int) _ParseTensorName(string tensor_name) | ||||
| @@ -173,6 +230,14 @@ namespace Tensorflow | |||||
| return graph_def; | return graph_def; | ||||
| } | } | ||||
| private static GraphDef _ProcessGraphDefParam(GraphDef graph_def) | |||||
| { | |||||
| var old_graph_def = graph_def; | |||||
| graph_def = new GraphDef(old_graph_def); | |||||
| return graph_def; | |||||
| } | |||||
| private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) | private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) | ||||
| { | { | ||||
| foreach (var attr_def in op_def.Attr) | foreach (var attr_def in op_def.Attr) | ||||
| @@ -240,6 +305,35 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private static void _RemoveDefaultAttrs(OpList producer_op_list, GraphDef graph_def) | |||||
| { | |||||
| var producer_op_dict = producer_op_list.Op.ToDictionary(x => x.Name, x => x); | |||||
| foreach (var node in graph_def.Node) | |||||
| { | |||||
| // Remove any default attr values that aren't in op_def. | |||||
| if (producer_op_dict.ContainsKey(node.Op)) | |||||
| { | |||||
| var op_def = op_def_registry.GetOpDef(node.Op); | |||||
| if(op_def is null) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| var producer_op_def = producer_op_dict[node.Op]; | |||||
| foreach (var key in node.Attr.Keys) | |||||
| { | |||||
| if (_FindAttrInOpDef(key, op_def) is null) | |||||
| { | |||||
| var attr_def = _FindAttrInOpDef(key, producer_op_def); | |||||
| if (attr_def != null && attr_def.DefaultValue != null && | |||||
| node.Attr[key] == attr_def.DefaultValue) | |||||
| node.Attr[key].ClearValue(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) | private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) | ||||
| { | { | ||||
| return op_def.Attr.FirstOrDefault(x => x.Name == name); | return op_def.Attr.FirstOrDefault(x => x.Name == name); | ||||
| @@ -0,0 +1,12 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Framework | |||||
| { | |||||
| public class versions | |||||
| { | |||||
| public static int GRAPH_DEF_VERSION = 1286; | |||||
| public static int GRAPH_DEF_VERSION_MIN_CONSUMER = 0; | |||||
| } | |||||
| } | |||||
| @@ -13,6 +13,7 @@ namespace Tensorflow.Functions | |||||
| /// </summary> | /// </summary> | ||||
| public class ConcreteFunction: Trackable | public class ConcreteFunction: Trackable | ||||
| { | { | ||||
| protected IEnumerable<Tensor> _captured_inputs; | |||||
| internal FuncGraph func_graph; | internal FuncGraph func_graph; | ||||
| internal ForwardBackwardCall forward_backward; | internal ForwardBackwardCall forward_backward; | ||||
| public Tensor[] Inputs => func_graph.Inputs; | public Tensor[] Inputs => func_graph.Inputs; | ||||
| @@ -29,11 +30,13 @@ namespace Tensorflow.Functions | |||||
| public ConcreteFunction(string name) | public ConcreteFunction(string name) | ||||
| { | { | ||||
| func_graph = new FuncGraph(name); | func_graph = new FuncGraph(name); | ||||
| _captured_inputs = func_graph.external_captures; | |||||
| } | } | ||||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | ||||
| { | { | ||||
| func_graph = graph; | func_graph = graph; | ||||
| _captured_inputs = func_graph.external_captures; | |||||
| ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | ||||
| } | } | ||||
| @@ -53,6 +56,7 @@ namespace Tensorflow.Functions | |||||
| new[] { output }, | new[] { output }, | ||||
| null); | null); | ||||
| func_graph.Exit(); | func_graph.Exit(); | ||||
| _captured_inputs = func_graph.external_captures; | |||||
| } | } | ||||
| public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | ||||
| @@ -73,6 +77,7 @@ namespace Tensorflow.Functions | |||||
| new[] { output.variant_tensor }, | new[] { output.variant_tensor }, | ||||
| null); | null); | ||||
| func_graph.Exit(); | func_graph.Exit(); | ||||
| _captured_inputs = func_graph.external_captures; | |||||
| } | } | ||||
| /*public ConcreteFunction(Func<Tensors, Tensors> func, | /*public ConcreteFunction(Func<Tensors, Tensors> func, | ||||
| @@ -174,6 +179,11 @@ namespace Tensorflow.Functions | |||||
| // TODO(Rinne); complete it with `_delayed_rewrite_functions`. | // TODO(Rinne); complete it with `_delayed_rewrite_functions`. | ||||
| } | } | ||||
| public void SetExternalCaptures(IEnumerable<Tensor> captures) | |||||
| { | |||||
| _captured_inputs = captures; | |||||
| } | |||||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ||||
| { | { | ||||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -0,0 +1,12 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Functions | |||||
| { | |||||
| public interface IGenericFunction | |||||
| { | |||||
| object[] Apply(params object[] args); | |||||
| ConcreteFunction get_concrete_function(params object[] args); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,88 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Operations; | |||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Functions | |||||
| { | |||||
| public static class function_saved_model_utils | |||||
| { | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="concrete_function"></param> | |||||
| /// <param name="inputs">a list tensors or other objects (such as variables) which | |||||
| /// contain tensors that were originally captured by the function</param> | |||||
| public static void restore_captures(ConcreteFunction concrete_function, IEnumerable<object> inputs) | |||||
| { | |||||
| var bound_inputs = inputs?.Select(obj => | |||||
| { | |||||
| if(obj is Tensor tensor) | |||||
| { | |||||
| return get_tensor_from_node(tensor); | |||||
| } | |||||
| else if(obj is IVariableV1 variable) | |||||
| { | |||||
| return get_tensor_from_node(variable); | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new TypeError("Encountered an type error, please submit an issue to " + | |||||
| "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
| } | |||||
| }); | |||||
| var bound_variables = inputs.TakeWhile(obj => obj is IVariableV1); | |||||
| List<Tensor> captured_inputs_list = new(); | |||||
| // TODO(Rinne): concrete_function.set_variables(bound_variables) | |||||
| if (bound_inputs is not null) | |||||
| { | |||||
| foreach(var (bound_input, internal_capture) in zip(bound_inputs, concrete_function.Inputs.Skip(concrete_function.Inputs.Length - bound_inputs.Count()))) | |||||
| { | |||||
| if(hasattr(bound_input, "__tf_experimental_restore_capture__")) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| else | |||||
| { | |||||
| captured_inputs_list.Add(bound_input); | |||||
| concrete_function.func_graph.replace_capture(bound_input, internal_capture); | |||||
| if(internal_capture.dtype == dtypes.resource) | |||||
| { | |||||
| // skip the check of variable. | |||||
| handle_data_util.copy_handle_data(bound_input, internal_capture); | |||||
| } | |||||
| concrete_function.func_graph.capture(bound_input); | |||||
| } | |||||
| } | |||||
| } | |||||
| if(captured_inputs_list.Any(inp => inp is null)) | |||||
| { | |||||
| // TODO(Rinne): add warnings. | |||||
| } | |||||
| concrete_function.SetExternalCaptures(captured_inputs_list); | |||||
| } | |||||
| public static Tensor get_tensor_from_node(Tensor node) | |||||
| { | |||||
| return node; | |||||
| } | |||||
| public static Tensor get_tensor_from_node(IVariableV1 node) | |||||
| { | |||||
| if (resource_variable_ops.is_resource_variable(node)) | |||||
| { | |||||
| return node.Handle; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new TypeError("Encountered an type error, please submit an issue to " + | |||||
| "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Gradients | |||||
| { | |||||
| public class custom_gradient | |||||
| { | |||||
| public static string generate_name() | |||||
| { | |||||
| return $"CustomGradient-{ops.uid()}"; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,6 +1,7 @@ | |||||
| using MethodBoundaryAspect.Fody.Attributes; | using MethodBoundaryAspect.Fody.Attributes; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| @@ -21,8 +22,9 @@ namespace Tensorflow.Graphs | |||||
| public override void OnEntry(MethodExecutionArgs args) | public override void OnEntry(MethodExecutionArgs args) | ||||
| { | { | ||||
| File.WriteAllText(@"D:\temp\for_test.txt", "jyfgjyfjhfjhc"); | |||||
| // TODO: func_name can be cache in FullName + Args | // TODO: func_name can be cache in FullName + Args | ||||
| func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}"; | |||||
| func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; | |||||
| if (functions.ContainsKey(func_name)) | if (functions.ContainsKey(func_name)) | ||||
| { | { | ||||
| @@ -56,6 +56,11 @@ public class FuncGraph : Graph, IDisposable | |||||
| _handle = handle; | _handle = handle; | ||||
| } | } | ||||
| public void replace_capture(Tensor tensor, Tensor placeholder) | |||||
| { | |||||
| _captures[tensor.Id] = (tensor, placeholder); | |||||
| } | |||||
| public void ToGraph(Operation[] opers, | public void ToGraph(Operation[] opers, | ||||
| Tensor[] inputs, Tensor[] outputs, | Tensor[] inputs, Tensor[] outputs, | ||||
| string[] output_names) | string[] output_names) | ||||
| @@ -146,6 +146,12 @@ namespace Tensorflow | |||||
| return ops.set_default_graph(this); | return ops.set_default_graph(this); | ||||
| } | } | ||||
| public bool IsFunction(string name) | |||||
| { | |||||
| // TODO(Rinne): deal with `_functions`. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| private Tensor _as_graph_element(object obj) | private Tensor _as_graph_element(object obj) | ||||
| { | { | ||||
| if (obj is RefVariable var) | if (obj is RefVariable var) | ||||
| @@ -28,6 +28,8 @@ public sealed class ImportGraphDefOptions | |||||
| _handle = c_api.TF_NewImportGraphDefOptions(); | _handle = c_api.TF_NewImportGraphDefOptions(); | ||||
| } | } | ||||
| public SafeImportGraphDefOptionsHandle Options => _handle; | |||||
| public void AddReturnOutput(string name, int index) | public void AddReturnOutput(string name, int index) | ||||
| { | { | ||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | ||||
| @@ -185,6 +185,9 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); | public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints(SafeImportGraphDefOptionsHandle options, bool validate_colocation_constraints); | |||||
| /// <summary> | /// <summary> | ||||
| /// Add an output in `graph_def` to be returned via the `return_outputs` output | /// Add an output in `graph_def` to be returned via the `return_outputs` output | ||||
| /// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input | /// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input | ||||
| @@ -246,7 +249,7 @@ namespace Tensorflow | |||||
| /// <param name="ops">TF_ImportGraphDefOptions*</param> | /// <param name="ops">TF_ImportGraphDefOptions*</param> | ||||
| /// <param name="uniquify_prefix">unsigned char</param> | /// <param name="uniquify_prefix">unsigned char</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix); | |||||
| public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, bool uniquify_prefix); | |||||
| /// <summary> | /// <summary> | ||||
| /// Fetches the return operations requested via | /// Fetches the return operations requested via | ||||
| @@ -0,0 +1,28 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow.Operations | |||||
| { | |||||
| public static class handle_data_util | |||||
| { | |||||
| public static void copy_handle_data(Tensor source_t, Tensor target_t) | |||||
| { | |||||
| if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) | |||||
| { | |||||
| SafeTensorHandle handle_data; | |||||
| if(source_t is EagerTensor) | |||||
| { | |||||
| handle_data = source_t.Handle; | |||||
| } | |||||
| else | |||||
| { | |||||
| handle_data = ops.get_resource_handle_data(source_t); | |||||
| } | |||||
| throw new NotImplementedException(); | |||||
| //if(handle_data is not null && handle_data.) | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -126,7 +126,7 @@ namespace Tensorflow | |||||
| /// <param name="handle"></param> | /// <param name="handle"></param> | ||||
| /// <param name="handle_data"></param> | /// <param name="handle_data"></param> | ||||
| /// <param name="graph_mode"></param> | /// <param name="graph_mode"></param> | ||||
| private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||||
| internal static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||||
| { | { | ||||
| if (!graph_mode) | if (!graph_mode) | ||||
| return; | return; | ||||
| @@ -5,6 +5,7 @@ | |||||
| #pragma warning disable 1591, 0612, 3021 | #pragma warning disable 1591, 0612, 3021 | ||||
| #region Designer generated code | #region Designer generated code | ||||
| using Tensorflow.Framework.Models; | |||||
| using pb = global::Google.Protobuf; | using pb = global::Google.Protobuf; | ||||
| using pbc = global::Google.Protobuf.Collections; | using pbc = global::Google.Protobuf.Collections; | ||||
| using pbr = global::Google.Protobuf.Reflection; | using pbr = global::Google.Protobuf.Reflection; | ||||
| @@ -2589,9 +2590,17 @@ namespace Tensorflow { | |||||
| } | } | ||||
| } | } | ||||
| #region Nested types | |||||
| /// <summary>Container for nested types declared in the FunctionSpec message type.</summary> | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| //public static FunctionSpec from_function_and_signature(string csharp_function, IEnumerable<TensorSpec> input_signature, bool is_pure = false, object jit_compile = null) | |||||
| //{ | |||||
| // // TODO(Rinne): _validate_signature(input_signature) | |||||
| // // TODO(Rinne): _validate_python_function(python_function, input_signature) | |||||
| //} | |||||
| #region Nested types | |||||
| /// <summary>Container for nested types declared in the FunctionSpec message type.</summary> | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public static partial class Types { | public static partial class Types { | ||||
| /// <summary> | /// <summary> | ||||
| /// Whether the function should be compiled by XLA. | /// Whether the function should be compiled by XLA. | ||||
| @@ -1,14 +1,24 @@ | |||||
| using System; | |||||
| using Google.Protobuf; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | |||||
| using System.Text; | using System.Text; | ||||
| using System.Text.RegularExpressions; | |||||
| using Tensorflow.Framework; | |||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| using Tensorflow.Gradients; | |||||
| using Tensorflow.Graphs; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Training.Saving.SavedModel | namespace Tensorflow.Training.Saving.SavedModel | ||||
| { | { | ||||
| public static class function_deserialization | public static class function_deserialization | ||||
| { | { | ||||
| private static string _INFERENCE_PREFIX = "__inference_"; | |||||
| private static string _FUNCTION_WRAPPER_NAME_REGEX = $@"^{_INFERENCE_PREFIX}(.*)_\d+$"; | |||||
| /// <summary> | /// <summary> | ||||
| /// Creates a `Function` from a `SavedFunction`. | /// Creates a `Function` from a `SavedFunction`. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -22,6 +32,338 @@ namespace Tensorflow.Training.Saving.SavedModel | |||||
| return null; | return null; | ||||
| } | } | ||||
| public static Dictionary<string, ConcreteFunction> load_function_def_library(FunctionDefLibrary library, | |||||
| SavedObjectGraph saved_object_graph = null, string load_shared_name_suffix = null, object? wrapper_function = null) | |||||
| { | |||||
| var library_function_names = library.Function.Select(x => x.Signature.Name).Distinct(); | |||||
| Dictionary<string, ConcreteFunction> functions = new(); | |||||
| Dictionary<string, ConcreteFunction> renamed_functions = new(); | |||||
| Graph graph; | |||||
| if (ops.executing_eagerly_outside_functions()) | |||||
| { | |||||
| graph = new Graph(); | |||||
| } | |||||
| else | |||||
| { | |||||
| graph = ops.get_default_graph(); | |||||
| } | |||||
| if(load_shared_name_suffix is null) | |||||
| { | |||||
| load_shared_name_suffix = $"_load_{ops.uid()}"; | |||||
| } | |||||
| Dictionary<ByteString, string> library_gradient_names = new(); | |||||
| Dictionary<ByteString, string> new_gradient_op_types = new(); | |||||
| Dictionary<string, string> gradients_to_register = new(); | |||||
| foreach (var gdef in library.RegisteredGradients) | |||||
| { | |||||
| if(gdef.RegisteredOpType is not null) | |||||
| { | |||||
| var new_op_type = custom_gradient.generate_name(); | |||||
| var old_op_type = tf.compat.as_bytes(gdef.RegisteredOpType); | |||||
| library_gradient_names[old_op_type] = gdef.GradientFunc; | |||||
| new_gradient_op_types[old_op_type] = new_op_type; | |||||
| gradients_to_register[gdef.GradientFunc] = new_op_type; | |||||
| } | |||||
| } | |||||
| Dictionary<string, IEnumerable<string>> function_deps = new(); | |||||
| foreach(var fdef in library.Function) | |||||
| { | |||||
| function_deps[fdef.Signature.Name] = _list_function_deps(fdef, library_function_names, library_gradient_names); | |||||
| } | |||||
| Dictionary<string, ConcreteFunction> loaded_gradients = new(); | |||||
| int aa = 0; | |||||
| var temp = _sort_function_defs(library, function_deps); | |||||
| foreach (var fdef in temp) | |||||
| { | |||||
| aa++; | |||||
| var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); | |||||
| if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) | |||||
| { | |||||
| // TODO(Rinne): implement it. | |||||
| //var proto = saved_object_graph.ConcreteFunctions[orig_name]; | |||||
| //throw new NotImplementedException(); | |||||
| } | |||||
| graph.as_default(); | |||||
| var func_graph = function_def_lib.function_def_to_graph(fdef, null, null); | |||||
| graph.Exit(); | |||||
| _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients); | |||||
| foreach(var dep in function_deps[orig_name]) | |||||
| { | |||||
| functions[dep].AddTograph(func_graph); | |||||
| } | |||||
| if (fdef.Attr.ContainsKey("_input_shapes")) | |||||
| { | |||||
| fdef.Attr.Remove("_input_shapes"); | |||||
| } | |||||
| var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value.S.ToString())); | |||||
| if(wrapper_function is not null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| func.AddTograph(graph); | |||||
| functions[orig_name] = func; | |||||
| renamed_functions[func.Name] = func; | |||||
| if(func_graph.get_operations().Any(op => op.op.type == "TRTEngineOp")) | |||||
| { | |||||
| func.AddTograph(ops.get_default_graph()); | |||||
| } | |||||
| if (gradients_to_register.ContainsKey(orig_name)) | |||||
| { | |||||
| var gradient_op_type = gradients_to_register[orig_name]; | |||||
| loaded_gradients[gradient_op_type] = func; | |||||
| // TODO(Rinne): deal with gradient registry. | |||||
| //new RegisteredGradient() { RegisteredOpType = gradient_op_type }. | |||||
| } | |||||
| } | |||||
| return functions; | |||||
| } | |||||
| public static void fix_node_def(NodeDef node_def, IDictionary<string, ConcreteFunction> functions, string shared_name_suffix) | |||||
| { | |||||
| if (functions.ContainsKey(node_def.Op)) | |||||
| { | |||||
| node_def.Op = functions[node_def.Op].Name; | |||||
| } | |||||
| foreach(var attr_value in node_def.Attr.Values) | |||||
| { | |||||
| if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func) | |||||
| { | |||||
| attr_value.Func.Name = functions[attr_value.Func.Name].Name; | |||||
| } | |||||
| else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List) | |||||
| { | |||||
| foreach(var fn in attr_value.List.Func) | |||||
| { | |||||
| fn.Name = functions[fn.Name].Name; | |||||
| } | |||||
| } | |||||
| } | |||||
| if(node_def.Op == "HashTableV2") | |||||
| { | |||||
| if(!node_def.Attr.ContainsKey("use_node_name_sharing") || !node_def.Attr["use_node_name_sharing"].B) | |||||
| { | |||||
| node_def.Attr["use_node_name_sharing"].B = true; | |||||
| shared_name_suffix += $"_{ops.uid()}"; | |||||
| } | |||||
| } | |||||
| var op_def = op_def_registry.GetOpDef(node_def.Op); | |||||
| if(op_def is not null) | |||||
| { | |||||
| var attr = op_def.Attr.Where(x => x.Name == "shared_name").FirstOrDefault(); | |||||
| if(attr is not null) | |||||
| { | |||||
| ByteString shared_name = null; | |||||
| if(node_def.Attr.ContainsKey("shared_name") && node_def.Attr["shared_name"].S is not null) | |||||
| { | |||||
| shared_name = node_def.Attr["shared_name"].S; | |||||
| } | |||||
| else if(attr.DefaultValue.S is not null) | |||||
| { | |||||
| shared_name = tf.compat.as_bytes(attr.DefaultValue.S); | |||||
| } | |||||
| if(shared_name is null) | |||||
| { | |||||
| shared_name = tf.compat.as_bytes(node_def.Name); | |||||
| } | |||||
| node_def.Attr["shared_name"].S = ByteString.CopyFrom(shared_name.Concat(tf.compat.as_bytes(node_def.Name)).ToArray()); | |||||
| } | |||||
| } | |||||
| } | |||||
| private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients) | |||||
| { | |||||
| foreach(var op in func_graph.get_operations()) | |||||
| { | |||||
| if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | |||||
| { | |||||
| var function = renamed_functions[tf.compat.as_bytes(op.op.node_def.Attr["f"].Func.Name).ToString()]; | |||||
| // TODO(Rinne): deal with `op._gradient_function`. | |||||
| } | |||||
| string gradient_op_type = null; | |||||
| try | |||||
| { | |||||
| gradient_op_type = op.op.get_attr("_gradient_op_type") as string; | |||||
| } | |||||
| catch(Exception e) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| if (loaded_gradients.ContainsKey(gradient_op_type)) | |||||
| { | |||||
| var grad_fn = loaded_gradients[gradient_op_type]; | |||||
| grad_fn.NumPositionArgs = op.op.inputs.Length; | |||||
| grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name); | |||||
| } | |||||
| } | |||||
| } | |||||
| private static string _fix_fdef_in_place(FunctionDef fdef, IDictionary<string, ConcreteFunction> functions, string shared_name_suffix, | |||||
| IDictionary<ByteString, string> new_gradient_op_types) | |||||
| { | |||||
| var orig_name = fdef.Signature.Name; | |||||
| bool contains_unsaved_custom_gradients = false; | |||||
| foreach(var node_def in fdef.NodeDef) | |||||
| { | |||||
| fix_node_def(node_def, functions, shared_name_suffix); | |||||
| var op_type = _get_gradient_op_type(node_def); | |||||
| if(op_type is not null) | |||||
| { | |||||
| if (new_gradient_op_types.ContainsKey(op_type)) | |||||
| { | |||||
| node_def.Attr["_gradient_op_type"].S = tf.compat.as_bytes(new_gradient_op_types[op_type]); | |||||
| } | |||||
| else | |||||
| { | |||||
| contains_unsaved_custom_gradients = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (contains_unsaved_custom_gradients) | |||||
| { | |||||
| // TODO(Rinne): log warnings. | |||||
| } | |||||
| fdef.Signature.Name = _clean_function_name(fdef.Signature.Name); | |||||
| return orig_name; | |||||
| } | |||||
| private static string _clean_function_name(string name) | |||||
| { | |||||
| var match = Regex.Match(name, _FUNCTION_WRAPPER_NAME_REGEX); | |||||
| if(match.Success) | |||||
| { | |||||
| return match.Groups[1].Value; | |||||
| } | |||||
| else | |||||
| { | |||||
| return name; | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Return a topologic sort of FunctionDefs in a library. | |||||
| /// </summary> | |||||
| /// <param name="library"></param> | |||||
| /// <param name="function_deps"></param> | |||||
| private static IEnumerable<FunctionDef> _sort_function_defs(FunctionDefLibrary library, Dictionary<string, IEnumerable<string>> function_deps) | |||||
| { | |||||
| Dictionary<string, IList<string>> edges = new(); | |||||
| Dictionary<string, int> in_count = new(); | |||||
| foreach(var item in function_deps) | |||||
| { | |||||
| var fname = item.Key; | |||||
| var deps = item.Value; | |||||
| if(deps is null || deps.Count() == 0) | |||||
| { | |||||
| in_count[fname] = 0; | |||||
| continue; | |||||
| } | |||||
| foreach(var dep in deps) | |||||
| { | |||||
| edges.SetDefault(dep, new List<string>()).Add(fname); | |||||
| if (in_count.ContainsKey(fname)) | |||||
| { | |||||
| in_count[fname]++; | |||||
| } | |||||
| else | |||||
| { | |||||
| in_count[fname] = 1; | |||||
| } | |||||
| } | |||||
| } | |||||
| var ready = new Stack<string>(library.Function. | |||||
| Where(x => in_count[x.Signature.Name] == 0) | |||||
| .Select(x => x.Signature.Name).ToList()); | |||||
| List<string> output = new(); | |||||
| while(ready.Count > 0) | |||||
| { | |||||
| var node = ready.Pop(); | |||||
| output.Add(node); | |||||
| if (!edges.ContainsKey(node)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| foreach(var dest in edges[node]) | |||||
| { | |||||
| in_count[dest] -= 1; | |||||
| if (in_count[dest] == 0) | |||||
| { | |||||
| ready.Push(dest); | |||||
| } | |||||
| } | |||||
| } | |||||
| if(output.Count != library.Function.Count) | |||||
| { | |||||
| var failed_to_resolve = in_count.Keys.Except(output); | |||||
| throw new ValueError($"There is a cyclic dependency between functions. " + | |||||
| $"Could not resolve ({string.Join(", ", failed_to_resolve)})."); | |||||
| } | |||||
| var reverse = library.Function.ToDictionary(x => x.Signature.Name, x => x); | |||||
| return output.Select(x => reverse[x]); | |||||
| } | |||||
| private static IEnumerable<string> _list_function_deps(FunctionDef fdef, IEnumerable<string> library_function_names, IDictionary<ByteString, string> library_gradient_names) | |||||
| { | |||||
| HashSet<string> deps = new HashSet<string>(); | |||||
| foreach(var node_def in fdef.NodeDef) | |||||
| { | |||||
| var grad_op_type = _get_gradient_op_type(node_def); | |||||
| if (library_function_names.Contains(node_def.Op)) | |||||
| { | |||||
| deps.Add(node_def.Op); | |||||
| } | |||||
| else if(grad_op_type is not null && library_gradient_names.TryGetValue(grad_op_type, out var gradient_name)) | |||||
| { | |||||
| deps.Add(gradient_name); | |||||
| } | |||||
| else | |||||
| { | |||||
| foreach(var attr_value in node_def.Attr.Values) | |||||
| { | |||||
| if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func) | |||||
| { | |||||
| deps.Add(attr_value.Func.Name); | |||||
| } | |||||
| else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List) | |||||
| { | |||||
| foreach(var fn in attr_value.List.Func) | |||||
| { | |||||
| deps.Add(fn.Name); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return deps.AsEnumerable(); | |||||
| } | |||||
| private static ByteString _get_gradient_op_type(NodeDef node_def) | |||||
| { | |||||
| if(node_def.Attr.ContainsKey("_gradient_op_type") && node_def.Op != "StatefulPartitionedCall" && node_def.Op != "PartitionedCall") | |||||
| { | |||||
| return node_def.Attr["_gradient_op_type"].S; | |||||
| } | |||||
| return null; | |||||
| } | |||||
| public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, | public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, | ||||
| IDictionary<string, ConcreteFunction> concrete_functions) | IDictionary<string, ConcreteFunction> concrete_functions) | ||||
| { | { | ||||
| @@ -30,6 +372,7 @@ namespace Tensorflow.Training.Saving.SavedModel | |||||
| concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; | concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; | ||||
| var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); | var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); | ||||
| // TODO(Rinne): set the functiona spec. | |||||
| concrete_function.AddTograph(); | concrete_function.AddTograph(); | ||||
| return concrete_function; | return concrete_function; | ||||
| } | } | ||||
| @@ -35,6 +35,8 @@ namespace Tensorflow | |||||
| private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; | private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; | ||||
| private List<Trackable> _nodes; | private List<Trackable> _nodes; | ||||
| private Dictionary<int, Action<object, object, object>> _node_setters; | private Dictionary<int, Action<object, object, object>> _node_setters; | ||||
| private Dictionary<string, ConcreteFunction> _concrete_functions; | |||||
| private HashSet<string> _restored_concrete_functions; | |||||
| public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir, | public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir, | ||||
| CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters) | CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters) | ||||
| { | { | ||||
| @@ -44,6 +46,9 @@ namespace Tensorflow | |||||
| _proto = object_graph_proto; | _proto = object_graph_proto; | ||||
| _export_dir = export_dir; | _export_dir = export_dir; | ||||
| // TODO: `this._concrete_functions` and `this._restored_concrete_functions` | // TODO: `this._concrete_functions` and `this._restored_concrete_functions` | ||||
| _concrete_functions = function_deserialization.load_function_def_library( | |||||
| meta_graph.GraphDef.Library, _proto); | |||||
| _restored_concrete_functions = new HashSet<string>(); | |||||
| _checkpoint_options = ckpt_options; | _checkpoint_options = ckpt_options; | ||||
| _save_options = save_options; | _save_options = save_options; | ||||
| @@ -464,9 +469,17 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private void _setup_function_captures() | |||||
| private void _setup_function_captures(string concrete_function_name, Dictionary<Maybe<string, int>, Trackable> nodes) | |||||
| { | { | ||||
| // TODO: implement it with concrete functions. | |||||
| if (_restored_concrete_functions.Contains(concrete_function_name)) | |||||
| { | |||||
| return; | |||||
| } | |||||
| _restored_concrete_functions.Add(concrete_function_name); | |||||
| var concrete_function = _concrete_functions[concrete_function_name]; | |||||
| var proto = _proto.ConcreteFunctions[concrete_function_name]; | |||||
| var inputs = proto.BoundInputs.Select(x => nodes[x]); | |||||
| function_saved_model_utils.restore_captures(concrete_function, inputs); | |||||
| } | } | ||||
| private void _setup_remaining_functions() | private void _setup_remaining_functions() | ||||
| @@ -625,7 +638,7 @@ namespace Tensorflow | |||||
| var fn = function_deserialization.recreate_function(proto, null); | var fn = function_deserialization.recreate_function(proto, null); | ||||
| foreach (var name in proto.ConcreteFunctions) | foreach (var name in proto.ConcreteFunctions) | ||||
| { | { | ||||
| _setup_function_captures(); | |||||
| _setup_function_captures(name, dependencies); | |||||
| } | } | ||||
| return (fn, setattr); | return (fn, setattr); | ||||
| } | } | ||||
| @@ -633,8 +646,9 @@ namespace Tensorflow | |||||
| private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | ||||
| Dictionary<Maybe<string, int>, Trackable> dependencies) | Dictionary<Maybe<string, int>, Trackable> dependencies) | ||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| //var fn = function_deserialization.setup_bare_concrete_function(proto, ) | |||||
| var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); | |||||
| _setup_function_captures(proto.ConcreteFunctionName, dependencies); | |||||
| return (fn, setattr); | |||||
| } | } | ||||
| // TODO: remove this to a common class. | // TODO: remove this to a common class. | ||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Training.Saving.SavedModel | |||||
| { | |||||
| //public class nested_structure_coder | |||||
| //{ | |||||
| // public static List<object> decode_proto(StructuredValue proto) | |||||
| // { | |||||
| // return proto s | |||||
| // } | |||||
| //} | |||||
| } | |||||
| @@ -572,6 +572,11 @@ namespace Tensorflow | |||||
| return get_default_graph().building_function; | return get_default_graph().building_function; | ||||
| } | } | ||||
| public static SafeTensorHandle get_resource_handle_data(Tensor graph_op) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public static void dismantle_graph(Graph graph) | public static void dismantle_graph(Graph graph) | ||||
| { | { | ||||
| @@ -8,9 +8,14 @@ namespace Tensorflow.Keras.Saving | |||||
| { | { | ||||
| public class KerasMetaData | public class KerasMetaData | ||||
| { | { | ||||
| [JsonProperty("name")] | |||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| [JsonProperty("class_name")] | [JsonProperty("class_name")] | ||||
| public string ClassName { get; set; } | public string ClassName { get; set; } | ||||
| [JsonProperty("trainable")] | |||||
| public bool Trainable { get; set; } | |||||
| [JsonProperty("dtype")] | |||||
| public TF_DataType DType { get; set; } = TF_DataType.DtInvalid; | |||||
| [JsonProperty("is_graph_network")] | [JsonProperty("is_graph_network")] | ||||
| public bool IsGraphNetwork { get; set; } | public bool IsGraphNetwork { get; set; } | ||||
| [JsonProperty("shared_object_id")] | [JsonProperty("shared_object_id")] | ||||
| @@ -20,5 +25,13 @@ namespace Tensorflow.Keras.Saving | |||||
| public JObject Config { get; set; } | public JObject Config { get; set; } | ||||
| [JsonProperty("build_input_shape")] | [JsonProperty("build_input_shape")] | ||||
| public TensorShapeConfig BuildInputShape { get; set; } | public TensorShapeConfig BuildInputShape { get; set; } | ||||
| [JsonProperty("batch_input_shape")] | |||||
| public TensorShapeConfig BatchInputShape { get; set; } | |||||
| [JsonProperty("activity_regularizer")] | |||||
| public IRegularizer ActivityRegularizer { get; set; } | |||||
| [JsonProperty("input_spec")] | |||||
| public JToken InputSpec { get; set; } | |||||
| [JsonProperty("stateful")] | |||||
| public bool? Stateful { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Saving | |||||
| { | { | ||||
| public class KerasObjectLoader | public class KerasObjectLoader | ||||
| { | { | ||||
| private static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||||
| internal static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||||
| private SavedMetadata _metadata; | private SavedMetadata _metadata; | ||||
| private SavedObjectGraph _proto; | private SavedObjectGraph _proto; | ||||
| private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | ||||
| @@ -311,6 +311,10 @@ namespace Tensorflow.Keras.Saving | |||||
| { | { | ||||
| (obj, setter) = _revive_custom_object(identifier, metadata); | (obj, setter) = _revive_custom_object(identifier, metadata); | ||||
| } | } | ||||
| if(obj is null) | |||||
| { | |||||
| throw new ValueError($"Cannot revive {metadata.Name} from the config or customized object."); | |||||
| } | |||||
| Debug.Assert(obj is Layer); | Debug.Assert(obj is Layer); | ||||
| _maybe_add_serialized_attributes(obj as Layer, metadata); | _maybe_add_serialized_attributes(obj as Layer, metadata); | ||||
| return (obj, setter); | return (obj, setter); | ||||
| @@ -349,8 +353,14 @@ namespace Tensorflow.Keras.Saving | |||||
| private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | ||||
| { | { | ||||
| // TODO(Rinne): implement it. | |||||
| throw new NotImplementedException(); | |||||
| if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) | |||||
| { | |||||
| return RevivedLayer.init_from_metadata(metadata); | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) | Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) | ||||
| @@ -403,9 +413,13 @@ namespace Tensorflow.Keras.Saving | |||||
| var obj = generic_utils.deserialize_keras_object(class_name, config); | var obj = generic_utils.deserialize_keras_object(class_name, config); | ||||
| if(obj is null) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| obj.Name = metadata.Name; | obj.Name = metadata.Name; | ||||
| // TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` | // TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` | ||||
| var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); | var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); | ||||
| if (!built) | if (!built) | ||||
| @@ -0,0 +1,62 @@ | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.Text; | |||||
| using System.Text.RegularExpressions; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Utils; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel | |||||
| { | |||||
| internal static class ReviveUtils | |||||
| { | |||||
| public static T recursively_deserialize_keras_object<T>(JToken config) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| if(config is JObject jobject) | |||||
| { | |||||
| if (jobject.ContainsKey("class_name")) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| public static void _revive_setter(object layer, object name, object value) | |||||
| { | |||||
| Debug.Assert(name is string); | |||||
| Debug.Assert(layer is Layer); | |||||
| if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | |||||
| { | |||||
| if (value is Trackable trackable) | |||||
| { | |||||
| (layer as Layer)._track_trackable(trackable, name as string); | |||||
| } | |||||
| (layer as Layer).SerializedAttributes[name] = JToken.FromObject(value); | |||||
| } | |||||
| else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||||
| { | |||||
| Debug.Assert(value is Trackable); | |||||
| functional._track_trackable(value as Trackable, name as string); | |||||
| } | |||||
| else | |||||
| { | |||||
| var properties = layer.GetType().GetProperties(); | |||||
| foreach (var p in properties) | |||||
| { | |||||
| if ((string)name == p.Name) | |||||
| { | |||||
| if(p.GetValue(layer) is not null) | |||||
| { | |||||
| return; | |||||
| } | |||||
| p.SetValue(layer, value); | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,37 @@ | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel | |||||
| { | |||||
| [JsonConverter(typeof(CustomizedRevivedConfigJsonConverter))] | |||||
| public class RevivedConfig: IKerasConfig | |||||
| { | |||||
| public JObject Config { get; set; } | |||||
| } | |||||
| public class CustomizedRevivedConfigJsonConverter : JsonConverter | |||||
| { | |||||
| public override bool CanConvert(Type objectType) | |||||
| { | |||||
| return objectType == typeof(RevivedConfig); | |||||
| } | |||||
| public override bool CanRead => true; | |||||
| public override bool CanWrite => true; | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
| { | |||||
| ((RevivedConfig)value).Config.WriteTo(writer); | |||||
| } | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
| { | |||||
| var config = (JObject)serializer.Deserialize(reader, typeof(JObject)); | |||||
| return new RevivedConfig() { Config = config }; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,73 @@ | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Utils; | |||||
| using Tensorflow.Keras.Saving.SavedModel; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel | |||||
| { | |||||
| public class RevivedLayer: Layer | |||||
| { | |||||
| public static (RevivedLayer, Action<object, object, object>) init_from_metadata(KerasMetaData metadata) | |||||
| { | |||||
| LayerArgs args = new LayerArgs() | |||||
| { | |||||
| Name = metadata.Name, | |||||
| Trainable = metadata.Trainable | |||||
| }; | |||||
| if(metadata.DType != TF_DataType.DtInvalid) | |||||
| { | |||||
| args.DType = metadata.DType; | |||||
| } | |||||
| if(metadata.BatchInputShape is not null) | |||||
| { | |||||
| args.BatchInputShape = metadata.BatchInputShape; | |||||
| } | |||||
| RevivedLayer revived_obj = new RevivedLayer(args); | |||||
| // TODO(Rinne): implement `expects_training_arg`. | |||||
| var config = metadata.Config; | |||||
| if (generic_utils.validate_config(config)) | |||||
| { | |||||
| revived_obj._config = new RevivedConfig() { Config = config }; | |||||
| } | |||||
| if(metadata.InputSpec is not null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| if(metadata.ActivityRegularizer is not null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| // TODO(Rinne): `_is_feature_layer` | |||||
| if(metadata.Stateful is not null) | |||||
| { | |||||
| revived_obj.stateful = metadata.Stateful.Value; | |||||
| } | |||||
| return (revived_obj, ReviveUtils._revive_setter); | |||||
| } | |||||
| private RevivedConfig _config = null; | |||||
| public RevivedLayer(LayerArgs args): base(args) | |||||
| { | |||||
| } | |||||
| public override string ToString() | |||||
| { | |||||
| return $"Customized keras layer: {Name}."; | |||||
| } | |||||
| public override IKerasConfig get_config() | |||||
| { | |||||
| return _config; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -23,6 +23,7 @@ using System.Data; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Reflection; | using System.Reflection; | ||||
| using System.Security.AccessControl; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
| @@ -60,6 +61,10 @@ namespace Tensorflow.Keras.Utils | |||||
| public static Layer deserialize_keras_object(string class_name, JToken config) | public static Layer deserialize_keras_object(string class_name, JToken config) | ||||
| { | { | ||||
| var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); | var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); | ||||
| if(argType is null) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) | var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) | ||||
| .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); | .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); | ||||
| var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); | var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); | ||||
| @@ -72,6 +77,10 @@ namespace Tensorflow.Keras.Utils | |||||
| public static Layer deserialize_keras_object(string class_name, LayerArgs args) | public static Layer deserialize_keras_object(string class_name, LayerArgs args) | ||||
| { | { | ||||
| var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); | var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); | ||||
| if (layer is null) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| Debug.Assert(layer is Layer); | Debug.Assert(layer is Layer); | ||||
| return layer as Layer; | return layer as Layer; | ||||
| } | } | ||||
| @@ -1,10 +1,12 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
| using Tensorflow.Keras.UnitTest.Helpers; | using Tensorflow.Keras.UnitTest.Helpers; | ||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.KerasApi; | |||||
| namespace TensorFlowNET.Keras.UnitTest.SaveModel; | namespace TensorFlowNET.Keras.UnitTest.SaveModel; | ||||
| @@ -56,4 +58,11 @@ public class SequentialModelLoad | |||||
| model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Temp() | |||||
| { | |||||
| var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func"); | |||||
| model.summary(); | |||||
| } | |||||
| } | } | ||||