| @@ -14,6 +14,7 @@ namespace Tensorflow.Functions | |||
| { | |||
| IntPtr _handle; | |||
| FuncGraph func_graph; | |||
| public Tensor[] Inputs => func_graph.Inputs; | |||
| public Tensor[] CapturedInputs => func_graph.external_captures; | |||
| public string Name | |||
| @@ -127,30 +128,53 @@ namespace Tensorflow.Functions | |||
| func_graph.Exit(); | |||
| } | |||
| public Tensors Invoke(Tensors inputs) | |||
| public Tensors FilteredCall(Tensors inputs) | |||
| { | |||
| var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); | |||
| var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||
| Tensors flat_outputs = null; | |||
| if (tf.Context.executing_eagerly()) | |||
| flat_outputs = forward_function.Call(args_with_tangents); | |||
| forward_backward.Record(flat_outputs); | |||
| return flat_outputs; | |||
| return CallFlat(inputs, CapturedInputs); | |||
| } | |||
| /// <summary> | |||
| /// Executes the wrapped function. | |||
| /// </summary> | |||
| /// <param name="args"></param> | |||
| /// <param name="captured_inputs"></param> | |||
| /// <returns></returns> | |||
| public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs) | |||
| { | |||
| var new_args = new List<Tensor>(); | |||
| new_args.AddRange(args); | |||
| new_args.AddRange(captured_inputs); | |||
| args = new_args.ToArray(); | |||
| var executing_eagerly = tf.Context.executing_eagerly(); | |||
| var default_graph = ops.get_default_graph(); | |||
| var tensor_inputs = new Tensors(); | |||
| foreach (var (i, arg) in enumerate(args)) | |||
| { | |||
| tensor_inputs.Add(arg); | |||
| // If we're graph building, shape inference is on. | |||
| if (!executing_eagerly) | |||
| { | |||
| } | |||
| } | |||
| tensor_inputs.AddRange(captured_inputs); | |||
| args = tensor_inputs.ToArray(); | |||
| var attrs = new object[] | |||
| var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0; | |||
| // No tape is watching; skip to running the function. | |||
| if (possible_gradient_type == 0 && executing_eagerly) | |||
| { | |||
| "executor_type", "", | |||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
| }; | |||
| return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||
| var attrs = new object[] | |||
| { | |||
| "executor_type", "", | |||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
| }; | |||
| return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||
| } | |||
| var forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||
| var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||
| Tensors flat_outputs = null; | |||
| if (executing_eagerly) | |||
| flat_outputs = forward_function.Call(args_with_tangents); | |||
| forward_backward.Record(flat_outputs); | |||
| return flat_outputs; | |||
| } | |||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
| @@ -31,11 +31,17 @@ namespace Tensorflow.Functions | |||
| public Tensors Call(Tensors args) | |||
| { | |||
| var attrs = new object[] | |||
| { | |||
| "executor_type", "", | |||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
| }; | |||
| var results = tf.Runner.TFE_Execute(tf.Context, | |||
| tf.Context.DeviceName, | |||
| _func_graph.FuncName, | |||
| args, | |||
| null, | |||
| attrs, | |||
| _num_outputs); | |||
| return results; | |||
| @@ -49,24 +49,61 @@ namespace Tensorflow.Functions | |||
| getBackwardFunction: () => backward_function); | |||
| } | |||
| /// <summary> | |||
| /// Create a backward function given `outputs` from the forward function. | |||
| /// </summary> | |||
| /// <param name="forward_graph"></param> | |||
| /// <param name="backward"></param> | |||
| /// <param name="outputs"></param> | |||
| /// <returns></returns> | |||
| (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | |||
| { | |||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
| var capture_mapping = new Dictionary<long, Tensor>(); | |||
| foreach(var (i, output) in enumerate(outputs)) | |||
| capture_mapping[forward_graph.Outputs[i].Id] = output; | |||
| var remapped_captures = new Tensors(); | |||
| foreach(var capture in backward.CapturedInputs) | |||
| { | |||
| if (capture_mapping.ContainsKey(capture.Id)) | |||
| remapped_captures.Add(capture_mapping[capture.Id]); | |||
| } | |||
| var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; | |||
| var recorded_outputs = new Tensors(); | |||
| var relevant_outputs = outputs; | |||
| var trainable_recorded_outputs = 0; | |||
| var skip_positions = new List<int>(); | |||
| foreach (var (output_index, output) in enumerate(relevant_outputs)) | |||
| { | |||
| if (trainable_recorded_outputs < backward_function_inputs) | |||
| recorded_outputs.Add(output); | |||
| if (gradients_util.IsTrainable(output)) | |||
| trainable_recorded_outputs += 1; | |||
| else | |||
| skip_positions.Add(output_index); | |||
| } | |||
| BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) => | |||
| { | |||
| var processed_args = new List<Tensor>(); | |||
| var processed_args = new Tensors(); | |||
| var input_index = 0; | |||
| foreach (var (output_index, arg) in enumerate(output_grads)) | |||
| foreach (var (output_index, arg) in enumerate(args)) | |||
| { | |||
| if (arg is null) | |||
| if (skip_positions.Contains(output_index)) | |||
| continue; | |||
| if (arg == null) | |||
| throw new NotImplementedException(""); | |||
| processed_args.add(arg); | |||
| processed_args.Add(arg); | |||
| input_index += 1; | |||
| if (input_index >= backward_function_inputs) | |||
| break; | |||
| } | |||
| tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||
| return backward.CallFlat(processed_args.ToArray(), outputs); | |||
| return backward.CallFlat(processed_args, remapped_captures); | |||
| }; | |||
| return (_backward_function_wrapper, outputs); | |||
| return (_backward_function_wrapper, recorded_outputs); | |||
| } | |||
| protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
| @@ -103,7 +140,7 @@ namespace Tensorflow.Functions | |||
| } | |||
| backwards_graph.Exit(); | |||
| var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | |||
| var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||
| var backward_function_attr = new Dictionary<string, string>(); | |||
| backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
| gradients_wrt_outputs.append(backwards_graph.internal_captures); | |||
| @@ -228,13 +228,14 @@ namespace Tensorflow.Gradients | |||
| var grad = grads[0]; | |||
| var x = op.inputs[0]; | |||
| var a = op.inputs[1]; | |||
| var size = array_ops.stack(new object[] { array_ops.rank(x), 1 }); | |||
| var pad_before = array_ops.slice(a, new[] { 0, 0 }, size); | |||
| var size = array_ops.stack(new Tensor[] { array_ops.rank(x), constant_op.constant(1) }); | |||
| var begin = constant_op.constant(new[] { 0, 0 }); | |||
| var pad_before = array_ops.slice(a, begin, size); | |||
| // Make it a 1-D tensor. | |||
| var begin = array_ops.reshape(pad_before, new[] { -1 }); | |||
| var sizes = array_ops.shape(x); | |||
| var x_grad = array_ops.slice(grad, begin, sizes); | |||
| begin = array_ops.reshape(pad_before, new[] { -1 }); | |||
| size = array_ops.shape(x); | |||
| var x_grad = array_ops.slice(grad, begin, size); | |||
| if (len(op.inputs) == 3) | |||
| return new Tensor[] { x_grad, null, null }; | |||
| @@ -30,7 +30,7 @@ namespace Tensorflow.Gradients | |||
| var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); | |||
| Tensor image_shape = null; | |||
| if (shape.is_fully_defined()) | |||
| throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined"); | |||
| image_shape = constant_op.constant(image.shape[1..3]); | |||
| else | |||
| image_shape = array_ops.shape(image)["1:3"]; | |||
| @@ -8,6 +8,9 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow.Graphs | |||
| { | |||
| /// <summary> | |||
| /// func_graph.py func_graph_from_py_func | |||
| /// </summary> | |||
| [AllowChangingInputArguments] | |||
| public sealed class AutoGraphAttribute : OnMethodBoundaryAspect | |||
| { | |||
| @@ -18,15 +21,16 @@ namespace Tensorflow.Graphs | |||
| public override void OnEntry(MethodExecutionArgs args) | |||
| { | |||
| func_name = $"{args.Method.Name}_{Guid.NewGuid()}"; | |||
| // TODO: func_name can be cache in FullName + Args | |||
| func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{Guid.NewGuid()}"; | |||
| if (functions.ContainsKey(func_name)) | |||
| { | |||
| function = functions[func_name]; | |||
| if (args.Arguments[0] is Tensors tensor_inputs) | |||
| args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs)); | |||
| args.ReturnValue = ConvertReturnValue(function.FilteredCall(tensor_inputs)); | |||
| else | |||
| args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray())); | |||
| args.ReturnValue = ConvertReturnValue(function.FilteredCall(args.Arguments.Select(x => x as Tensor).ToArray())); | |||
| args.FlowBehavior = FlowBehavior.Return; | |||
| return; | |||
| } | |||
| @@ -62,14 +66,27 @@ namespace Tensorflow.Graphs | |||
| { | |||
| if (args.ReturnValue is Tensors outputs) | |||
| { | |||
| if (args.Arguments[0] is Tensors inputs) | |||
| function.ToGraph(inputs, outputs); | |||
| Tensors inputs = null; | |||
| outputs = mark_as_return(outputs); | |||
| if (args.Arguments[0] is Tensors inputs1) | |||
| inputs = inputs1; | |||
| else | |||
| function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs); | |||
| inputs = args.Arguments.Select(x => x as Tensor).ToArray(); | |||
| inputs = inputs.Where(x => x.op.OpType == "Placeholder" | |||
| && x.op.name.StartsWith("inputs")).ToArray(); | |||
| function.ToGraph(inputs, outputs); | |||
| } | |||
| else | |||
| function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); | |||
| else if (args.ReturnValue is Tensor output) | |||
| { | |||
| var inputs = args.Arguments.Select(x => x as Tensor) | |||
| .Where(x => x.op.type == "Placeholder" && x.op.name.StartsWith("inputs")) | |||
| .ToArray(); | |||
| var outputs2 = array_ops.identity(output); | |||
| function.ToGraph(inputs, outputs2); | |||
| } | |||
| function.Exit(); | |||
| // cache function. | |||
| @@ -77,7 +94,7 @@ namespace Tensorflow.Graphs | |||
| functions[func_name] = function; | |||
| // run function | |||
| args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs)); | |||
| args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs)); | |||
| } | |||
| object ConvertReturnValue(Tensors tensors) | |||
| @@ -87,5 +104,20 @@ namespace Tensorflow.Graphs | |||
| else | |||
| return tensors; | |||
| } | |||
| /// <summary> | |||
| /// Acts like identity but marks the `Tensor` as a return value. | |||
| /// </summary> | |||
| /// <param name="tensors"></param> | |||
| /// <returns></returns> | |||
| public Tensors mark_as_return(Tensors tensors) | |||
| { | |||
| if (tensors == null) | |||
| return null; | |||
| var result = new Tensors(); | |||
| foreach (var tensor in tensors) | |||
| result.Add(array_ops.identity(tensor)); | |||
| return result; | |||
| } | |||
| } | |||
| } | |||
| @@ -925,7 +925,28 @@ namespace Tensorflow | |||
| public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||
| => gen_array_ops.slice(input, begin, size, name: name); | |||
| public static Tensor stack(object values, int axis = 0, string name = "stack") | |||
| public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) | |||
| => tf.Context.RunInAutoMode2( | |||
| () => tf.OpDefLib._apply_op_helper("Slice", name, new | |||
| { | |||
| input, begin, size | |||
| }).output, | |||
| () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Slice", name, | |||
| null, | |||
| input, begin, size).FirstOrDefault(), | |||
| (op) => | |||
| { | |||
| var attrs = new object[] | |||
| { | |||
| "T", op.get_attr<TF_DataType>("T"), | |||
| "Index", op.get_attr<int>("Index") | |||
| }; | |||
| tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs); | |||
| }, | |||
| new Tensors(input, begin, size)); | |||
| public static Tensor stack(object values, int axis = 0, string name = "stack") | |||
| { | |||
| if (axis == 0) | |||
| // If the input is a constant list, it can be converted to a constant op | |||
| @@ -238,18 +238,32 @@ namespace Tensorflow | |||
| "half_pixel_centers", half_pixel_centers).FirstOrDefault(), | |||
| images); | |||
| public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false, | |||
| public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false, | |||
| bool half_pixel_centers = false, string name = null) | |||
| { | |||
| var op = tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new | |||
| { | |||
| grads, | |||
| size, | |||
| align_corners, | |||
| half_pixel_centers | |||
| }); | |||
| return op.output; | |||
| } | |||
| => tf.Context.RunInAutoMode2( | |||
| () => tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name, new | |||
| { | |||
| grads, | |||
| size, | |||
| align_corners, | |||
| half_pixel_centers | |||
| }).output, | |||
| () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "ResizeNearestNeighborGrad", name, | |||
| null, | |||
| grads, size, | |||
| "align_corners", align_corners, | |||
| "half_pixel_centers", half_pixel_centers).FirstOrDefault(), | |||
| (op) => | |||
| { | |||
| var attrs = new object[] | |||
| { | |||
| "T", op.get_attr<TF_DataType>("T"), | |||
| "align_corners", op.get_attr<bool>("align_corners"), | |||
| "half_pixel_centers", op.get_attr<bool>("half_pixel_centers") | |||
| }; | |||
| tf.Runner.RecordGradient("ResizeNearestNeighborGrad", op.inputs, attrs, op.outputs); | |||
| }, | |||
| new Tensors(grads, size)); | |||
| } | |||
| } | |||
| @@ -126,6 +126,16 @@ namespace Tensorflow | |||
| public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, | |||
| string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "RandomShuffle", name, | |||
| null, | |||
| value, seed, seed2); | |||
| return results[0]; | |||
| } | |||
| var _op = tf.OpDefLib._apply_op_helper("RandomShuffle", | |||
| name: name, | |||
| args: new { value, seed, seed2 }); | |||
| @@ -83,6 +83,7 @@ TensorFlow .NET v0.30 is focused on making more Keras API work including: | |||
| <ItemGroup> | |||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | |||
| <PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" /> | |||
| <PackageReference Include="NumSharp.Lite" Version="0.1.10" /> | |||
| <PackageReference Include="Protobuf.Text" Version="0.4.0" /> | |||
| <PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" /> | |||
| @@ -57,6 +57,9 @@ namespace Tensorflow | |||
| public void Add(Tensor tensor) | |||
| => items.Add(tensor); | |||
| public void AddRange(Tensor[] tensors) | |||
| => items.AddRange(tensors); | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -48,7 +48,7 @@ namespace Tensorflow | |||
| public tensorflow() | |||
| { | |||
| Logger = new LoggerConfiguration() | |||
| .MinimumLevel.Error() | |||
| .MinimumLevel.Warning() | |||
| .WriteTo.Console() | |||
| .CreateLogger(); | |||
| @@ -16,6 +16,7 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Linq; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Graphs; | |||
| @@ -197,7 +198,7 @@ namespace Tensorflow.Keras | |||
| } | |||
| if (outputs[0].op.type == "Placeholder" | |||
| || outputs[0].op.type == "StridedSlice") | |||
| return exec_graph.external_captures[0].numpy(); | |||
| return exec_graph.external_captures.Last().numpy(); | |||
| // Consolidate updates | |||
| exec_graph.as_default(); | |||
| @@ -0,0 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| public interface ITensorFlowOpLayer | |||
| { | |||
| Layer GetOpLayer(TensorFlowOpLayerArgs args); | |||
| } | |||
| } | |||
| @@ -1,5 +1,4 @@ | |||
| using NumSharp; | |||
| using ShellProgressBar; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| @@ -88,15 +87,8 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| stop_training = false; | |||
| _train_counter.assign(0); | |||
| var options = new ProgressBarOptions | |||
| { | |||
| ProgressCharacter = '.', | |||
| ProgressBarOnBottom = true | |||
| }; | |||
| foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
| { | |||
| using var pbar = new ProgressBar(data_handler.Inferredsteps, "Training...", options); | |||
| // reset_metrics(); | |||
| // callbacks.on_epoch_begin(epoch) | |||
| // data_handler.catch_stop_iteration(); | |||
| @@ -105,7 +97,7 @@ namespace Tensorflow.Keras.Engine | |||
| // callbacks.on_train_batch_begin(step) | |||
| var results = step_function(iterator); | |||
| var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); | |||
| pbar.Tick($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]"); | |||
| Console.WriteLine($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]"); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,66 +0,0 @@ | |||
| using NumSharp; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| public class TensorFlowOpLayer : Layer | |||
| { | |||
| TensorFlowOpLayerArgs args; | |||
| Dictionary<int, NDArray> constants => args.Constants; | |||
| NodeDef node_def => args.NodeDef; | |||
| static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_"; | |||
| public string OpType => node_def.Op; | |||
| public TensorFlowOpLayer(TensorFlowOpLayerArgs args) | |||
| : base(new LayerArgs | |||
| { | |||
| Name = TF_OP_LAYER_NAME_PREFIX + args.Name, | |||
| Trainable = args.Trainable, | |||
| DType = args.DType, | |||
| Autocast = false | |||
| }) | |||
| { | |||
| this.args = args; | |||
| built = true; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| return _defun_call(inputs); | |||
| return MakOp(inputs); | |||
| } | |||
| [AutoGraph] | |||
| Tensors _defun_call(Tensors inputs) | |||
| => MakOp(inputs); | |||
| Tensors MakOp(Tensors inputs) | |||
| { | |||
| foreach (var (index, constant) in enumerate(constants)) | |||
| { | |||
| var value = constant_op.constant(constant, name: node_def.Input[index]); | |||
| var new_inputs = inputs.ToList(); | |||
| new_inputs.Insert(index, value); | |||
| inputs = new Tensors(new_inputs.ToArray()); | |||
| } | |||
| var graph = inputs.graph; | |||
| var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); | |||
| var op = graph._create_op_from_tf_operation(c_op); | |||
| op._control_flow_post_processing(); | |||
| // Record the gradient because custom-made ops don't go through the | |||
| // code-gen'd eager call path | |||
| var op_type = op.node_def.Op; | |||
| tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | |||
| return op.outputs; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,7 @@ | |||
| using System.Collections.Generic; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Reflection; | |||
| using System.Linq; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Datasets; | |||
| using Tensorflow.Keras.Engine; | |||
| @@ -1,4 +1,4 @@ | |||
| <Project Sdk="Microsoft.NET.Sdk"> | |||
| <Project Sdk="Microsoft.NET.Sdk"> | |||
| <PropertyGroup> | |||
| <TargetFramework>netstandard2.0</TargetFramework> | |||
| @@ -47,7 +47,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | |||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.3" /> | |||
| <PackageReference Include="SharpZipLib" Version="1.3.1" /> | |||
| <PackageReference Include="ShellProgressBar" Version="5.0.0" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -18,6 +18,7 @@ using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Reflection; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using static Tensorflow.Binding; | |||
| @@ -151,7 +152,7 @@ namespace Tensorflow.Keras.Utils | |||
| // recursively | |||
| CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); | |||
| Layer op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs | |||
| var op_layer = GetLayer<ITensorFlowOpLayer>(new TensorFlowOpLayerArgs | |||
| { | |||
| NodeDef = op.node_def, | |||
| Constants = constants, | |||
| @@ -164,6 +165,20 @@ namespace Tensorflow.Keras.Utils | |||
| } | |||
| } | |||
| static Layer GetLayer<T>(LayerArgs args) | |||
| { | |||
| Layer layer = default; | |||
| var assemble = Assembly.Load("TensorFlow.Keras.Layers"); | |||
| foreach (var type in assemble.GetTypes().Where(x => x.GetInterface(typeof(T).Name) != null)) | |||
| { | |||
| layer = (Layer)Activator.CreateInstance(type, new object[] { args }); | |||
| } | |||
| if (layer == null) | |||
| throw new NotImplementedException($"Can't find implementation for type {args.GetType().Name}"); | |||
| return layer; | |||
| } | |||
| // recusive | |||
| static bool uses_keras_history(Tensor op_input) | |||
| { | |||