diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs index 4d2c2dec..bf5eba19 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs @@ -1,6 +1,5 @@ using System; using System.Linq; -using Microsoft.Extensions.Logging; using Tensorflow.Gradients; using static Tensorflow.Binding; using static Tensorflow.tensorflow; @@ -39,7 +38,7 @@ namespace Tensorflow.Eager }*/ } - tf.Logger.LogDebug($"RecordGradient: should_record={should_record}, op_name={op_name}"); + tf.Logger.Debug($"RecordGradient: should_record={should_record}, op_name={op_name}"); if (!should_record) return should_record; Tensor[] op_outputs; diff --git a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs index 299be4a7..c39ec73f 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using Microsoft.Extensions.Logging; using Tensorflow.Util; using static Tensorflow.tensorflow; using static Tensorflow.Binding; @@ -36,7 +35,7 @@ namespace Tensorflow.Gradients foreach (var o in output_tensors) { tensor_tape_[o.GetID()] = op_id; - tf.Logger.LogDebug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}"); + tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}"); tensor_usage_[o.GetID()] = 1; tensors.Add(o); } diff --git a/src/TensorFlowNET.Core/Gradients/Tape.cs b/src/TensorFlowNET.Core/Gradients/Tape.cs index d43e4ee3..ddfa590e 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using Tensorflow.Util; -using Microsoft.Extensions.Logging; using static Tensorflow.Binding; using static Tensorflow.tensorflow; @@ -44,7 +43,7 @@ namespace Tensorflow.Gradients if (!CouldBackprop()) return; - tf.Logger.LogDebug($"Watch tensor_id={tensor_id}"); + tf.Logger.Debug($"Watch tensor_id={tensor_id}"); tensor_tape_.emplace(tensor_id, -1); } @@ -56,7 +55,7 @@ namespace Tensorflow.Gradients { if (IsDtypeTrainable(dtypes[i])) { - tf.Logger.LogDebug($"tape.h->ShouldRecord: should_record = true, tensor_tape_.size()={tensor_tape_.Count}, tensor_ids[{i}]={tensor_ids[i]}"); + tf.Logger.Debug($"tape.h->ShouldRecord: should_record = true, tensor_tape_.size()={tensor_tape_.Count}, tensor_ids[{i}]={tensor_ids[i]}"); return true; } } diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index 5176d3ca..6422c546 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -16,6 +16,7 @@ using System.Collections.Generic; using System.Linq; +using Tensorflow.Eager; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -82,7 +83,14 @@ namespace Tensorflow.Gradients .ToArray(); var out_grads = new List(); - if (constant_op.is_constant(concat_dim)) + if(concat_dim is EagerTensor) + { + var non_neg_concat_dim = (int)concat_dim % input_values[0].rank; + var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray(); + var sizes_tensor = constant_op.constant(sizes); + out_grads = gen_array_ops.split_v(grad, sizes_tensor, sizes[0], non_neg_concat_dim).ToList(); + } + else if (constant_op.is_constant(concat_dim)) { /*If concat_dim is a constant defined in a different context, then we duplicate it in the current context to avoid passing it @@ -97,33 +105,33 @@ namespace Tensorflow.Gradients var value = tensor_util.constant_value(concat_dim); concat_dim = constant_op.constant(value: value, dtype: concat_dim.dtype); } - } - // Using mod here for convenience since concat_dim is already verified - // in concat implementation to be within the allowed [-rank, rank) range. - var non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]); + // Using mod here for convenience since concat_dim is already verified + // in concat implementation to be within the allowed [-rank, rank) range. + var non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]); - // Get the inputs' tensor shapes - var sizes = _ExtractInputShapes(input_values); + // Get the inputs' tensor shapes + var sizes = _ExtractInputShapes(input_values); - /* The magic number of 16 was found through benchmarking a range of sizes - on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of - cases when switching implementations at N=16, but it is possible that - there will be a small number of performance regressions.*/ - if (len(sizes) > 16) - { - // extract the size of each input along the concat dimension - var slice = array_ops.slice(array_ops.stack(sizes, axis: 1), - new Tensor[] { non_neg_concat_dim, tf.constant(0) }, - new Tensor[] { tf.constant(1), tf.constant(-1) }); - var squeeze_sizes = array_ops.squeeze(slice); - out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList(); - } - else - { - var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes); - foreach (var (begin, size) in zip(offset, sizes)) - out_grads.Add(gen_array_ops.slice(grad, begin, size)); + /* The magic number of 16 was found through benchmarking a range of sizes + on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of + cases when switching implementations at N=16, but it is possible that + there will be a small number of performance regressions.*/ + if (len(sizes) > 16) + { + // extract the size of each input along the concat dimension + var slice = array_ops.slice(array_ops.stack(sizes, axis: 1), + new Tensor[] { non_neg_concat_dim, tf.constant(0) }, + new Tensor[] { tf.constant(1), tf.constant(-1) }); + var squeeze_sizes = array_ops.squeeze(slice); + out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList(); + } + else + { + var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes); + foreach (var (begin, size) in zip(offset, sizes)) + out_grads.Add(gen_array_ops.slice(grad, begin, size)); + } } return (end_value_index <= dim_index ? diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs index b777f29a..1914f61d 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs @@ -56,6 +56,8 @@ namespace Tensorflow.Graphs public override void OnExit(MethodExecutionArgs args) { + var returnValue = mark_as_return(args.ReturnValue as Tensors); + var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); if (args.ReturnValue is Tensors outputs) @@ -102,5 +104,10 @@ namespace Tensorflow.Graphs // run function args.ReturnValue = function(originalInputs); } + + Tensor mark_as_return(Tensor tensor) + { + return array_ops.identity(tensor); + } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LeakyReLUArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs similarity index 85% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/LeakyReLUArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs index c62d7a12..6bdb294c 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LeakyReLUArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow.Keras.ArgsDefinition { - public class LeakyReLUArgs : LayerArgs + public class LeakyReLuArgs : LayerArgs { /// /// Negative slope coefficient. diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv2DArgs.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/Conv2DArgs.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvolutionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvolutionalArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/DenseArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/EmbeddingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/EmbeddingArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/InputLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/InputLayerArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/BatchNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/BatchNormalizationArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/MaxPooling2D.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling2D.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/MaxPooling2D.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling2D.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling2DArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DropoutArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/DropoutArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/RescalingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/RescalingArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/FlattenArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs similarity index 100% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/FlattenArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index ffe20847..1b8aa8d1 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -531,6 +531,17 @@ namespace Tensorflow.Operations public static Tensor leaky_relu_grad(Tensor gradients, Tensor features, float alpha = 0.2f, string name = null) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "LeakyReluGrad", name, + null, + gradients, features, + "alpha", alpha); + + return results[0]; + } + var _op = tf.OpDefLib._apply_op_helper("LeakyReluGrad", name: name, args: new { gradients, diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 5e0f83e6..72439100 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -842,7 +842,7 @@ namespace Tensorflow // Restore shape information where possible. if (!tf.Context.executing_eagerly()) { - var paddings_constant = tensor_util.constant_value(result.op.inputs[1], partial: true); + var paddings_constant = tensor_util.constant_value(paddings); var input_shape = result.op.inputs[0].TensorShape; if (input_shape.ndim > -1 && !result.TensorShape.is_fully_defined() && diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index a5410185..99be2efe 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -496,6 +496,24 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor[] split_v(Tensor value, Tensor size_splits, + int axis, int num_split, string name = null) + { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "SplitV", name, + null, + value, size_splits, axis, + "num_split", num_split); + + return results; + } + + var _op = tf.OpDefLib._apply_op_helper("SplitV", name, new { split_dim = axis, value, num_split }); + return _op.outputs; + } + public static Tensor tile(Tensor input, T multiples, string name = null) => tf.Context.RunInAutoMode(() => tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index f6442c20..2b1d7b60 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -84,10 +84,8 @@ TensorFlow .NET v0.30 is focused on making more Keras API work including: - - - + diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 903ec3e6..2f91c0d3 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -14,9 +14,9 @@ limitations under the License. ******************************************************************************/ -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using System.Collections.Generic; +using Serilog; +using Serilog.Core; using Tensorflow.Contexts; using Tensorflow.Eager; using Tensorflow.Gradients; @@ -43,17 +43,14 @@ namespace Tensorflow public OpDefLibrary OpDefLib; public Context Context; public IEagerRunner Runner; - public ILogger Logger; - ServiceProvider serviceProvider; + public Logger Logger; public tensorflow() { - serviceProvider = new ServiceCollection() - .AddLogging(cfg => cfg.AddConsole()) - .Configure(cfg => cfg.MinLevel = LogLevel.Warning) - .BuildServiceProvider(); - - Logger = serviceProvider.GetService>(); + Logger = new LoggerConfiguration() + .MinimumLevel.Error() + .WriteTo.Console() + .CreateLogger(); Status = new Status(); Context = new Context(new ContextOptions(), Status); diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index f0f2ab2b..ac7f084c 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -3,7 +3,6 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Utils; -using Microsoft.Extensions.Logging; using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine @@ -336,7 +335,7 @@ namespace Tensorflow.Keras.Engine var layer_inputs = node.MapArguments(tensor_dict); - tf.Logger.LogDebug($"{node.Layer}: {node.Layer.Name}"); + tf.Logger.Debug($"{node.Layer}: {node.Layer.Name}"); var outputs = node.Layer.Apply(layer_inputs, is_training: training); // Update tensor_dict for next input diff --git a/src/TensorFlowNET.Keras/Layers/LeakyReLU.cs b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs similarity index 82% rename from src/TensorFlowNET.Keras/Layers/LeakyReLU.cs rename to src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs index 9693c466..625e81d4 100644 --- a/src/TensorFlowNET.Keras/Layers/LeakyReLU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs @@ -10,11 +10,11 @@ namespace Tensorflow.Keras.Layers /// /// Leaky version of a Rectified Linear Unit. /// - public class LeakyReLU : Layer + public class LeakyReLu : Layer { - LeakyReLUArgs args; + LeakyReLuArgs args; float alpha => args.Alpha; - public LeakyReLU(LeakyReLUArgs args) : base(args) + public LeakyReLu(LeakyReLuArgs args) : base(args) { this.args = args; } diff --git a/src/TensorFlowNET.Keras/Layers/Conv2D.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2D.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/Conv2D.cs rename to src/TensorFlowNET.Keras/Layers/Convolution/Conv2D.cs diff --git a/src/TensorFlowNET.Keras/Layers/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/Convolutional.cs rename to src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs diff --git a/src/TensorFlowNET.Keras/Layers/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/Dense.cs rename to src/TensorFlowNET.Keras/Layers/Core/Dense.cs diff --git a/src/TensorFlowNET.Keras/Layers/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/Embedding.cs rename to src/TensorFlowNET.Keras/Layers/Core/Embedding.cs diff --git a/src/TensorFlowNET.Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/InputLayer.cs rename to src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 5a41c76e..03125e03 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -321,7 +321,7 @@ namespace Tensorflow.Keras.Layers /// Negative slope coefficient. /// public Layer LeakyReLU(float alpha = 0.3f) - => new LeakyReLU(new LeakyReLUArgs + => new LeakyReLu(new LeakyReLuArgs { Alpha = alpha }); diff --git a/src/TensorFlowNET.Keras/Layers/Add.cs b/src/TensorFlowNET.Keras/Layers/Merging/Add.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/Add.cs rename to src/TensorFlowNET.Keras/Layers/Merging/Add.cs diff --git a/src/TensorFlowNET.Keras/Layers/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/Merge.cs rename to src/TensorFlowNET.Keras/Layers/Merging/Merge.cs diff --git a/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/BatchNormalization.cs rename to src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs diff --git a/src/TensorFlowNET.Keras/Layers/GlobalAveragePooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/GlobalAveragePooling2D.cs rename to src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs diff --git a/src/TensorFlowNET.Keras/Layers/GlobalPooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalPooling2D.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/GlobalPooling2D.cs rename to src/TensorFlowNET.Keras/Layers/Pooling/GlobalPooling2D.cs diff --git a/src/TensorFlowNET.Keras/Layers/MaxPooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/MaxPooling2D.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/MaxPooling2D.cs rename to src/TensorFlowNET.Keras/Layers/Pooling/MaxPooling2D.cs diff --git a/src/TensorFlowNET.Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/Pooling2D.cs rename to src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs diff --git a/src/TensorFlowNET.Keras/Layers/Dropout.cs b/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/Dropout.cs rename to src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs diff --git a/src/TensorFlowNET.Keras/Layers/Rescaling.cs b/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/Rescaling.cs rename to src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs diff --git a/src/TensorFlowNET.Keras/Engine/Flatten.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs similarity index 93% rename from src/TensorFlowNET.Keras/Engine/Flatten.cs rename to src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs index 53ce2b63..316cab8c 100644 --- a/src/TensorFlowNET.Keras/Engine/Flatten.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs @@ -1,9 +1,10 @@ using System; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; using static Tensorflow.Binding; -namespace Tensorflow.Keras.Engine +namespace Tensorflow.Keras.Layers { public class Flatten : Layer { diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs index 45ec506f..0510a25c 100644 --- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -151,12 +151,23 @@ namespace Tensorflow.Keras.Utils // recursively CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); - var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs + Layer op_layer = null; + /*var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs { NodeDef = op.node_def, Constants = constants, Name = op.name - }); + });*/ + op_layer = op.type switch + { + // "AddV2" => keras.layers.Add(), + _ => new TensorFlowOpLayer(new TensorFlowOpLayerArgs + { + NodeDef = op.node_def, + Constants = constants, + Name = op.name + }) + }; created_layers.Add(op_layer); op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); processed_ops.Add(op);