| @@ -0,0 +1,15 @@ | |||
| using Newtonsoft.Json; | |||
| namespace Tensorflow.Keras.ArgsDefinition; | |||
/// <summary>
/// Arguments used to construct a Normalization preprocessing layer.
/// </summary>
public class NormalizationArgs : PreprocessingLayerArgs
{
    /// <summary>
    /// Axis setting for the layer. May be null.
    /// </summary>
    [JsonProperty("axis")]
    public Axis? Axis { get; set; }

    /// <summary>
    /// Optional mean value supplied directly instead of being computed from data.
    /// </summary>
    [JsonProperty("mean")]
    public float? Mean { get; set; }

    /// <summary>
    /// Optional variance value supplied directly instead of being computed from data.
    /// </summary>
    [JsonProperty("variance")]
    public float? Variance { get; set; }

    /// <summary>
    /// Invert flag for the layer. Defaults to false.
    /// </summary>
    // Named explicitly so the serialized key follows the same lower-case
    // convention as the other properties of this class.
    [JsonProperty("invert")]
    public bool Invert { get; set; } = false;
}
| @@ -23,5 +23,6 @@ namespace Tensorflow.Keras | |||
| TensorShapeConfig BuildInputShape { get; } | |||
| TF_DataType DType { get; } | |||
| int count_params(); | |||
| void adapt(Tensor data, int? batch_size = null, int? steps = null); | |||
| } | |||
| } | |||
| @@ -156,6 +156,7 @@ namespace Tensorflow.Keras.Layers | |||
| IInitializer beta_initializer = null, | |||
| IInitializer gamma_initializer = null); | |||
| public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false); | |||
| public ILayer LeakyReLU(float alpha = 0.3f); | |||
| public ILayer LSTM(int units, | |||
| @@ -9,6 +9,9 @@ namespace Tensorflow.NumPy | |||
| { | |||
| public static long GetSize(Shape shape) | |||
| { | |||
| if (shape.IsNull) | |||
| return 0; | |||
| // scalar | |||
| if (shape.ndim == 0) | |||
| return 1; | |||
| @@ -159,5 +159,10 @@ namespace Tensorflow | |||
| } | |||
/// <summary>
/// Not implemented for this type.
/// </summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Trackable GetTrackable()
    => throw new NotImplementedException();
/// <summary>
/// Not implemented for this type.
/// </summary>
/// <param name="data">Unused.</param>
/// <param name="batch_size">Unused.</param>
/// <param name="steps">Unused.</param>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public void adapt(Tensor data, int? batch_size = null, int? steps = null)
    => throw new NotImplementedException();
| } | |||
| } | |||
| @@ -16,6 +16,7 @@ | |||
| using Serilog; | |||
| using Serilog.Core; | |||
| using System.Reflection; | |||
| using System.Threading; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| @@ -52,7 +53,29 @@ namespace Tensorflow | |||
| ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner()); | |||
| public IEagerRunner Runner => _runner.Value; | |||
| public IKerasApi keras { get; set; } | |||
private IKerasApi _keras;

/// <summary>
/// Lazily resolved Keras API entry point.
/// </summary>
/// <remarks>
/// On first access the "Tensorflow.Keras" assembly is loaded, the first concrete class that
/// implements <see cref="IKerasApi"/> is instantiated through its parameterless constructor,
/// and the instance is cached for later reads.
/// </remarks>
/// <exception cref="InvalidOperationException">
/// The assembly contains no concrete <see cref="IKerasApi"/> implementation.
/// </exception>
public IKerasApi keras
{
    get
    {
        if (_keras != null)
        {
            return _keras;
        }

        var kerasAssembly = Assembly.Load("Tensorflow.Keras");

        // Restrict the search to concrete classes: an interface or abstract class that
        // derives from IKerasApi would match a plain interface check but cannot be
        // instantiated by Activator.CreateInstance.
        var implementation = kerasAssembly.GetTypes()
            .FirstOrDefault(t => t.IsClass && !t.IsAbstract && typeof(IKerasApi).IsAssignableFrom(t));

        if (implementation == null)
        {
            throw new InvalidOperationException("Can't find keras library.");
        }

        // Direct cast: the type filter above guarantees the instance implements IKerasApi,
        // so a failure here should surface instead of silently caching null.
        _keras = (IKerasApi)Activator.CreateInstance(implementation);
        return _keras;
    }
}
| public tensorflow() | |||
| { | |||
| @@ -344,5 +344,10 @@ namespace Tensorflow.Keras.Engine | |||
/// <summary>
/// Returns the <c>args</c> object of this layer as its configuration.
/// </summary>
public virtual IKerasConfig get_config()
{
    return args;
}
/// <summary>
/// Base implementation of <c>adapt</c>; does nothing. Derived layers may override it.
/// </summary>
/// <param name="data">Unused by the base implementation.</param>
/// <param name="batch_size">Unused by the base implementation.</param>
/// <param name="steps">Unused by the base implementation.</param>
public virtual void adapt(Tensor data, int? batch_size = null, int? steps = null)
{
    // No state to compute in the base layer.
}
| } | |||
| } | |||
| @@ -20,10 +20,6 @@ namespace Tensorflow.Keras | |||
| { | |||
| private static KerasInterface _instance = null; | |||
| private static readonly object _lock = new object(); | |||
| private KerasInterface() | |||
| { | |||
| Tensorflow.Binding.tf.keras = this; | |||
| } | |||
| public static KerasInterface Instance | |||
| { | |||
| @@ -872,5 +872,14 @@ namespace Tensorflow.Keras.Layers | |||
| Sparse = sparse, | |||
| CountWeights = count_weights | |||
| }); | |||
/// <summary>
/// Creates a Normalization layer from the given settings.
/// </summary>
/// <param name="axis">Forwarded to <c>NormalizationArgs.Axis</c>.</param>
/// <param name="mean">Forwarded to <c>NormalizationArgs.Mean</c>.</param>
/// <param name="variance">Forwarded to <c>NormalizationArgs.Variance</c>.</param>
/// <param name="invert">Forwarded to <c>NormalizationArgs.Invert</c>.</param>
/// <returns>The newly constructed layer.</returns>
public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false)
{
    var layerArgs = new NormalizationArgs();
    layerArgs.Axis = axis;
    layerArgs.Mean = mean;
    layerArgs.Variance = variance;
    layerArgs.Invert = invert;
    return new Normalization(layerArgs);
}
| } | |||
| } | |||
| @@ -0,0 +1,173 @@ | |||
| /***************************************************************************** | |||
| Copyright 2023 Haiping Chen. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
/// <summary>
/// Preprocessing layer that normalizes its inputs (or de-normalizes them when
/// <c>Invert</c> is set) using a mean and variance that are either supplied through the
/// args or accumulated from data by <c>adapt</c>.
/// </summary>
public class Normalization : PreprocessingLayer
{
    NormalizationArgs _args;

    int[] axis;
    int[] _reduce_axis;
    IVariableV1 adapt_mean, adapt_variance, count;
    Tensor mean, variance;
    Shape _broadcast_shape;
    float? input_mean, input_variance;
    TF_DataType compute_dtype = tf.float32;

    /// <summary>
    /// Creates the layer from its arguments.
    /// </summary>
    /// <param name="args">Axis, optional static mean/variance and the invert flag.</param>
    /// <exception cref="System.ArgumentException">
    /// Only one of <c>Mean</c> and <c>Variance</c> was supplied.
    /// </exception>
    public Normalization(NormalizationArgs args) : base(args)
    {
        _args = args;

        // The static path in build() needs both values; reject a half-specified pair
        // here instead of failing later on a null variance (or mean).
        if ((args.Mean == null) != (args.Variance == null))
        {
            throw new System.ArgumentException(
                "When setting values directly, both `mean` and `variance` must be set.");
        }

        if (args.Axis == null)
        {
            axis = new int[0];
        }
        else
        {
            axis = args.Axis.axis;
        }
        input_mean = args.Mean;
        input_variance = args.Variance;
    }

    /// <summary>
    /// Derives the kept/reduced axes and the broadcast shape from <paramref name="input_shape"/>,
    /// then either creates the adapt variables (no static mean) or materializes the static
    /// mean and variance as tensors.
    /// </summary>
    /// <param name="input_shape">Shape of the layer input.</param>
    /// <exception cref="System.ArgumentException">An axis is outside [-ndim, ndim).</exception>
    public override void build(Shape input_shape)
    {
        base.build(input_shape);
        var ndim = input_shape.ndim;

        if (axis.Any(a => a < -ndim || a >= ndim))
        {
            throw new System.ArgumentException(
                $"All `axis` values must be in the range [-ndim, ndim). Found ndim: {ndim}, axis: [{string.Join(", ", axis)}]");
        }

        // Axes to keep, with negative values replaced by their positive equivalents.
        // Computed into a local (the shared args array is left untouched) and sorted so
        // that reshaping the statistics to the broadcast shape never transposes axes.
        var keep_axis = axis.Select(d => d >= 0 ? d : d + ndim).OrderBy(d => d).ToArray();
        _reduce_axis = range(ndim).Where(d => !keep_axis.Contains(d)).ToArray();
        // Broadcast any reduced axes.
        _broadcast_shape = new Shape(range(ndim).Select(d => keep_axis.Contains(d) ? input_shape.dims[d] : 1).ToArray());
        var mean_and_var_shape = keep_axis.Select(d => input_shape.dims[d]).ToArray();

        if (input_mean == null)
        {
            adapt_mean = add_weight("mean",
                mean_and_var_shape,
                dtype: tf.float32,
                initializer: tf.zeros_initializer,
                trainable: false);
            adapt_variance = add_weight("variance",
                mean_and_var_shape,
                dtype: tf.float32,
                initializer: tf.ones_initializer,
                trainable: false);
            count = add_weight("count",
                Shape.Scalar,
                dtype: tf.int64,
                initializer: tf.zeros_initializer,
                trainable: false);
            finalize_state();
        }
        else
        {
            // No-adapt case: build constant tensors with the broadcast shape used in Call.
            mean = input_mean * np.ones(mean_and_var_shape);
            variance = input_variance * np.ones(mean_and_var_shape);
            mean = tf.reshape(mean, _broadcast_shape);
            variance = tf.reshape(variance, _broadcast_shape);
            mean = tf.cast(mean, compute_dtype);
            variance = tf.cast(variance, compute_dtype);
        }
    }

    /// <summary>
    /// Resets the accumulated mean, variance and count to zeros, ones and zero.
    /// Does nothing when a static mean was supplied or the adapt variables do not exist yet.
    /// </summary>
    public override void reset_state()
    {
        // The adapt variables only exist after build() ran without a static mean.
        if (input_mean != null || adapt_mean == null)
        {
            return;
        }
        adapt_mean.assign(tf.zeros_like(adapt_mean.AsTensor()));
        adapt_variance.assign(tf.ones_like(adapt_variance.AsTensor()));
        count.assign(tf.zeros_like(count.AsTensor()));
    }

    /// <summary>
    /// Copies the accumulated statistics into the tensors used by <c>Call</c>, reshaped to
    /// the broadcast shape. Does nothing when a static mean was supplied or the adapt
    /// variables do not exist yet.
    /// </summary>
    public override void finalize_state()
    {
        // The adapt variables only exist after build() ran without a static mean.
        if (input_mean != null || adapt_mean == null)
        {
            return;
        }
        mean = tf.reshape(adapt_mean.AsTensor(), _broadcast_shape);
        variance = tf.reshape(adapt_variance.AsTensor(), _broadcast_shape);
    }

    /// <summary>
    /// Folds the moments of one batch into the running mean, variance and count.
    /// </summary>
    /// <param name="data">Batch of data; cast to the dtype of the running mean.</param>
    /// <exception cref="System.InvalidOperationException">
    /// The layer was constructed with a static mean and variance.
    /// </exception>
    public override void update_state(Tensor data)
    {
        if (input_mean != null)
        {
            throw new System.InvalidOperationException(
                "Cannot `adapt` a Normalization layer that is initialized with static `mean` and `variance`.");
        }

        data = tf.cast(data, adapt_mean.dtype);
        var (batch_mean, batch_variance) = tf.nn.moments(data, axes: _reduce_axis);
        var batch_shape = tf.shape(data, out_type: count.dtype);

        // With no reduced axes the batch counts as a single sample.
        var batch_count = constant_op.constant(1L);
        if (_reduce_axis != null && _reduce_axis.Length > 0)
        {
            var batch_reduce_shape = tf.gather(batch_shape, constant_op.constant(_reduce_axis));
            batch_count = tf.reduce_prod(batch_reduce_shape);
        }

        var total_count = batch_count + count.AsTensor();
        var batch_weight = tf.cast(batch_count, dtype: compute_dtype) / tf.cast(
            total_count, dtype: compute_dtype);
        var existing_weight = 1.0 - batch_weight;

        // Weighted combination of the existing and the batch statistics.
        var total_mean = adapt_mean.AsTensor() * existing_weight + batch_mean * batch_weight;
        var total_variance = (
            adapt_variance.AsTensor() + tf.square(adapt_mean.AsTensor() - total_mean)
        ) * existing_weight + (
            batch_variance + tf.square(batch_mean - total_mean)
        ) * batch_weight;

        adapt_mean.assign(total_mean);
        adapt_variance.assign(total_variance);
        count.assign(total_count);
    }

    /// <summary>
    /// The output shape equals the input shape.
    /// </summary>
    public override Shape ComputeOutputShape(Shape input_shape)
    {
        return input_shape;
    }

    /// <summary>
    /// Computes the layer statistics from <paramref name="data"/> by delegating to the
    /// base preprocessing-layer implementation.
    /// </summary>
    public override void adapt(Tensor data, int? batch_size = null, int? steps = null)
    {
        base.adapt(data, batch_size: batch_size, steps: steps);
    }

    /// <summary>
    /// Applies <c>(inputs - mean) / max(sqrt(variance), epsilon)</c>, or, when <c>Invert</c>
    /// is set, <c>mean + inputs * max(sqrt(variance), epsilon)</c>.
    /// </summary>
    protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
    {
        if (_args.Invert)
        {
            return mean + (
                inputs * tf.maximum(tf.sqrt(variance), keras.backend.epsilon())
            );
        }
        else
        {
            return (inputs - mean) / tf.maximum(
                tf.sqrt(variance), keras.backend.epsilon());
        }
    }
}
| } | |||
| @@ -3,14 +3,95 @@ using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Engine.DataAdapters; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
/// <summary>
/// Base class for preprocessing layers whose internal state can be computed from data
/// through <see cref="adapt"/>. Subclasses override <see cref="reset_state"/>,
/// <see cref="update_state"/> and <see cref="finalize_state"/>.
/// </summary>
public class PreprocessingLayer : Layer
{
    bool _is_compiled;
    bool _is_adapted;
    IVariableV1 _steps_per_execution;
    PreprocessingLayerArgs _args;

    public PreprocessingLayer(PreprocessingLayerArgs args) : base(args)
    {
        _args = args;
    }

    /// <summary>
    /// Runs one epoch over <paramref name="data"/>, feeding each batch to
    /// <see cref="update_state"/>, then calls <see cref="finalize_state"/>.
    /// The layer is compiled first if needed, and its state is reset if it is already built.
    /// </summary>
    /// <param name="data">Data to compute the state from.</param>
    /// <param name="batch_size">
    /// Batch size for the pass; when null the batch size from the layer args is used.
    /// </param>
    /// <param name="steps">
    /// NOTE(review): currently not forwarded to the data handler - confirm which
    /// DataHandlerArgs property should receive it.
    /// </param>
    public override void adapt(Tensor data, int? batch_size = null, int? steps = null)
    {
        if (!_is_compiled)
        {
            compile();
        }

        if (built)
        {
            reset_state();
        }

        var data_handler = new DataHandler(new DataHandlerArgs
        {
            X = new Tensors(data),
            // Honor the caller's batch size; fall back to the layer args otherwise.
            BatchSize = batch_size ?? _args.BatchSize,
            Epochs = 1,
            StepsPerExecution = _steps_per_execution
        });

        foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
        {
            foreach (var step in data_handler.steps())
            {
                run_step(iterator);
            }
        }
        finalize_state();
        _is_adapted = true;
    }

    /// <summary>
    /// Pulls the next batch from the iterator, builds the layer if necessary and updates
    /// the state with the first element of the batch.
    /// </summary>
    private void run_step(OwnedIterator iterator)
    {
        var data = iterator.next();
        _adapt_maybe_build(data[0]);
        update_state(data[0]);
    }

    /// <summary>Resets the layer state. No-op in this base class.</summary>
    public virtual void reset_state()
    {
    }

    /// <summary>Finalizes the layer state after adapting. No-op in this base class.</summary>
    public virtual void finalize_state()
    {
    }

    /// <summary>Updates the layer state with one batch. No-op in this base class.</summary>
    public virtual void update_state(Tensor data)
    {
    }

    /// <summary>
    /// Builds the layer from the shape of <paramref name="data"/> if it is not built yet.
    /// When no batch input shape is set, one with -1 for every dimension is recorded.
    /// </summary>
    private void _adapt_maybe_build(Tensor data)
    {
        if (built)
        {
            return;
        }

        var data_shape = data.shape;
        var data_shape_nones = Enumerable.Repeat(-1, data.ndim).ToArray();
        _args.BatchInputShape = BatchInputShape ?? new Shape(data_shape_nones);
        build(data_shape);
        built = true;
    }

    /// <summary>
    /// Creates the steps-per-execution variable and marks the layer as compiled.
    /// </summary>
    /// <param name="run_eagerly">Currently unused.</param>
    /// <param name="steps_per_execution">Initial value of the steps-per-execution variable.</param>
    public void compile(bool run_eagerly = false, int steps_per_execution = 1)
    {
        _steps_per_execution = tf.Variable(
            steps_per_execution,
            dtype: tf.int64,
            aggregation: VariableAggregation.OnlyFirstReplica
        );
        _is_compiled = true;
    }
}
| } | |||
| @@ -1,5 +1,6 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Tensorflow; | |||
| using Tensorflow.Keras; | |||
| using static Tensorflow.Binding; | |||
| @@ -177,6 +177,55 @@ namespace TensorFlowNET.Keras.UnitTest | |||
| Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f })); | |||
| } | |||
/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization
/// </summary>
[TestMethod]
public void Normalization()
{
    // Calculate a global mean and variance by analyzing the dataset in adapt().
    var adapt_data = np.array(new[] { 1f, 2f, 3f, 4f, 5f });
    var input_data = np.array(new[] { 1f, 2f, 3f });
    var layer = tf.keras.layers.Normalization(axis: null);
    layer.adapt(adapt_data);
    var x = layer.Apply(input_data);
    // Compared through the same Equal helper as the cases below, rather than an exact
    // Assert.AreEqual between an NDArray and a float array.
    Equal(x.numpy().ToArray<float>(), new[] { -1.4142135f, -0.70710677f, 0f });

    // Calculate a mean and variance for each index on the last axis.
    adapt_data = np.array(new[,]
    {
        { 0, 7, 4 },
        { 2, 9, 6 },
        { 0, 7, 4 },
        { 2, 9, 6 }
    }, dtype: tf.float32);
    input_data = np.array(new[,] { { 0, 7, 4 } }, dtype: tf.float32);
    layer = tf.keras.layers.Normalization(axis: -1);
    layer.adapt(adapt_data);
    x = layer.Apply(input_data);
    Equal(x.numpy().ToArray<float>(), new[] { -1f, -1f, -1f });

    // Pass the mean and variance directly.
    input_data = np.array(new[,] { { 1f }, { 2f }, { 3f } }, dtype: tf.float32);
    layer = tf.keras.layers.Normalization(mean: 3f, variance: 2f);
    x = layer.Apply(input_data);
    Equal(x.numpy().ToArray<float>(), new[] { -1.4142135f, -0.70710677f, 0f });

    // Use the layer to de-normalize inputs (after adapting the layer).
    adapt_data = np.array(new[,]
    {
        { 0, 7, 4 },
        { 2, 9, 6 },
        { 0, 7, 4 },
        { 2, 9, 6 }
    }, dtype: tf.float32);
    input_data = np.array(new[,] { { 1, 2, 3 } }, dtype: tf.float32);
    layer = tf.keras.layers.Normalization(axis: -1, invert: true);
    layer.adapt(adapt_data);
    x = layer.Apply(input_data);
    Equal(x.numpy().ToArray<float>(), new[] { -2f, -10f, -8f });
}
| /// <summary> | |||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/CategoryEncoding | |||
| /// </summary> | |||