| @@ -26,7 +26,7 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5 | |||
| #### Download pre-build package | |||
| [Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip) | |||
| [Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.10.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.10.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.10.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.10.0.zip), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.10.0.zip) | |||
| @@ -35,6 +35,6 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5 | |||
| On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries. | |||
| 1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux. | |||
| 2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600` | |||
| 2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.10.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600` | |||
| @@ -10,6 +10,9 @@ namespace Tensorflow | |||
| var diag = new Diagnostician(); | |||
| // diag.Diagnose(@"D:\memory.txt"); | |||
| var rnn = new SimpleRnnTest(); | |||
| rnn.Run(); | |||
| // this class is used explor new features. | |||
| var exploring = new Exploring(); | |||
| // exploring.Run(); | |||
| @@ -0,0 +1,31 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| namespace Tensorflow | |||
| { | |||
| public class SimpleRnnTest | |||
| { | |||
| public void Run() | |||
| { | |||
| tf.keras = new KerasInterface(); | |||
| var inputs = np.random.random((32, 10, 8)).astype(np.float32); | |||
| var simple_rnn = tf.keras.layers.SimpleRNN(4); | |||
| var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`. | |||
| if (output.shape == (32, 4)) | |||
| { | |||
| } | |||
| /*simple_rnn = tf.keras.layers.SimpleRNN( | |||
| 4, return_sequences = True, return_state = True) | |||
| # whole_sequence_output has shape `[32, 10, 4]`. | |||
| # final_state has shape `[32, 4]`. | |||
| whole_sequence_output, final_state = simple_rnn(inputs)*/ | |||
| } | |||
| } | |||
| } | |||
| @@ -6,7 +6,7 @@ | |||
| <RootNamespace>Tensorflow</RootNamespace> | |||
| <AssemblyName>Tensorflow</AssemblyName> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| <LangVersion>9.0</LangVersion> | |||
| <LangVersion>11.0</LangVersion> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
| @@ -20,7 +20,7 @@ | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="2.7.0" /> | |||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.10.0" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -1,22 +0,0 @@ | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class LSTMArgs : RNNArgs | |||
| { | |||
| public int Units { get; set; } | |||
| public Activation Activation { get; set; } | |||
| public Activation RecurrentActivation { get; set; } | |||
| public IInitializer KernelInitializer { get; set; } | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| public IInitializer BiasInitializer { get; set; } | |||
| public bool UnitForgetBias { get; set; } | |||
| public float Dropout { get; set; } | |||
| public float RecurrentDropout { get; set; } | |||
| public int Implementation { get; set; } | |||
| public bool ReturnSequences { get; set; } | |||
| public bool ReturnState { get; set; } | |||
| public bool GoBackwards { get; set; } | |||
| public bool Stateful { get; set; } | |||
| public bool TimeMajor { get; set; } | |||
| public bool Unroll { get; set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,12 @@ | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Lstm | |||
| { | |||
| public class LSTMArgs : RNNArgs | |||
| { | |||
| public bool UnitForgetBias { get; set; } | |||
| public float Dropout { get; set; } | |||
| public float RecurrentDropout { get; set; } | |||
| public int Implementation { get; set; } | |||
| } | |||
| } | |||
| @@ -1,4 +1,4 @@ | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| namespace Tensorflow.Keras.ArgsDefinition.Lstm | |||
| { | |||
| public class LSTMCellArgs : LayerArgs | |||
| { | |||
| @@ -1,21 +0,0 @@ | |||
| using System.Collections.Generic; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class RNNArgs : LayerArgs | |||
| { | |||
| public interface IRnnArgCell : ILayer | |||
| { | |||
| object state_size { get; } | |||
| } | |||
| public IRnnArgCell Cell { get; set; } = null; | |||
| public bool ReturnSequences { get; set; } = false; | |||
| public bool ReturnState { get; set; } = false; | |||
| public bool GoBackwards { get; set; } = false; | |||
| public bool Stateful { get; set; } = false; | |||
| public bool Unroll { get; set; } = false; | |||
| public bool TimeMajor { get; set; } = false; | |||
| public Dictionary<string, object> Kwargs { get; set; } = null; | |||
| } | |||
| } | |||
| @@ -0,0 +1,45 @@ | |||
| using System.Collections.Generic; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| public class RNNArgs : LayerArgs | |||
| { | |||
| public interface IRnnArgCell : ILayer | |||
| { | |||
| object state_size { get; } | |||
| } | |||
| public IRnnArgCell Cell { get; set; } = null; | |||
| public bool ReturnSequences { get; set; } = false; | |||
| public bool ReturnState { get; set; } = false; | |||
| public bool GoBackwards { get; set; } = false; | |||
| public bool Stateful { get; set; } = false; | |||
| public bool Unroll { get; set; } = false; | |||
| public bool TimeMajor { get; set; } = false; | |||
| public Dictionary<string, object> Kwargs { get; set; } = null; | |||
| public int Units { get; set; } | |||
| public Activation Activation { get; set; } | |||
| public Activation RecurrentActivation { get; set; } | |||
| public bool UseBias { get; set; } = true; | |||
| public IInitializer KernelInitializer { get; set; } | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| public IInitializer BiasInitializer { get; set; } | |||
| // kernel_regularizer=None, | |||
| // recurrent_regularizer=None, | |||
| // bias_regularizer=None, | |||
| // activity_regularizer=None, | |||
| // kernel_constraint=None, | |||
| // recurrent_constraint=None, | |||
| // bias_constraint=None, | |||
| // dropout=0., | |||
| // recurrent_dropout=0., | |||
| // return_sequences=False, | |||
| // return_state=False, | |||
| // go_backwards=False, | |||
| // stateful=False, | |||
| // unroll=False, | |||
| // **kwargs): | |||
| } | |||
| } | |||
| @@ -0,0 +1,7 @@ | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| public class SimpleRNNArgs : RNNArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -1,6 +1,6 @@ | |||
| using System.Collections.Generic; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| public class StackedRNNCellsArgs : LayerArgs | |||
| { | |||
| @@ -1,30 +0,0 @@ | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class SimpleRNNArgs : RNNArgs | |||
| { | |||
| public int Units { get; set; } | |||
| public Activation Activation { get; set; } | |||
| // units, | |||
| // activation='tanh', | |||
| // use_bias=True, | |||
| // kernel_initializer='glorot_uniform', | |||
| // recurrent_initializer='orthogonal', | |||
| // bias_initializer='zeros', | |||
| // kernel_regularizer=None, | |||
| // recurrent_regularizer=None, | |||
| // bias_regularizer=None, | |||
| // activity_regularizer=None, | |||
| // kernel_constraint=None, | |||
| // recurrent_constraint=None, | |||
| // bias_constraint=None, | |||
| // dropout=0., | |||
| // recurrent_dropout=0., | |||
| // return_sequences=False, | |||
| // return_state=False, | |||
| // go_backwards=False, | |||
| // stateful=False, | |||
| // unroll=False, | |||
| // **kwargs): | |||
| } | |||
| } | |||
| @@ -0,0 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Layers; | |||
| namespace Tensorflow.Keras | |||
| { | |||
| public interface IKerasApi | |||
| { | |||
| public ILayersApi layers { get; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,16 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras | |||
| { | |||
| public interface IPreprocessing | |||
| { | |||
| public ILayer Resizing(int height, int width, string interpolation = "bilinear"); | |||
| public ILayer TextVectorization(Func<Tensor, Tensor> standardize = null, | |||
| string split = "whitespace", | |||
| int max_tokens = -1, | |||
| string output_mode = "int", | |||
| int output_sequence_length = -1); | |||
| } | |||
| } | |||
| @@ -0,0 +1,20 @@ | |||
| using System; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Operations.Activation; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public partial interface ILayersApi | |||
| { | |||
| public ILayer ELU(float alpha = 0.1f); | |||
| public ILayer SELU(); | |||
| public ILayer Softmax(Axis axis); | |||
| public ILayer Softplus(); | |||
| public ILayer HardSigmoid(); | |||
| public ILayer Softsign(); | |||
| public ILayer Swish(); | |||
| public ILayer Tanh(); | |||
| public ILayer Exponential(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,28 @@ | |||
| using System; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public partial interface ILayersApi | |||
| { | |||
| public ILayer Attention(bool use_scale = false, | |||
| string score_mode = "dot", | |||
| bool causal = false, | |||
| float dropout = 0f); | |||
| public ILayer MultiHeadAttention(int num_heads, | |||
| int key_dim, | |||
| int? value_dim = null, | |||
| float dropout = 0f, | |||
| bool use_bias = true, | |||
| Shape output_shape = null, | |||
| Shape attention_axes = null, | |||
| IInitializer kernel_initializer = null, | |||
| IInitializer bias_initializer = null, | |||
| IRegularizer kernel_regularizer = null, | |||
| IRegularizer bias_regularizer = null, | |||
| IRegularizer activity_regularizer = null, | |||
| Action kernel_constraint = null, | |||
| Action bias_constraint = null); | |||
| } | |||
| } | |||
| @@ -0,0 +1,13 @@ | |||
| using System; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public partial interface ILayersApi | |||
| { | |||
| public ILayer Cropping1D(NDArray cropping); | |||
| public ILayer Cropping2D(NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last); | |||
| public ILayer Cropping3D(NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last); | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public partial interface ILayersApi | |||
| { | |||
| public ILayer Concatenate(int axis = -1); | |||
| } | |||
| } | |||
| @@ -0,0 +1,18 @@ | |||
| using System; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public partial interface ILayersApi | |||
| { | |||
| public ILayer Reshape(Shape target_shape); | |||
| public ILayer Reshape(object[] target_shape); | |||
| public ILayer UpSampling2D(Shape size = null, | |||
| string data_format = null, | |||
| string interpolation = "nearest"); | |||
| public ILayer ZeroPadding2D(NDArray padding); | |||
| } | |||
| } | |||
| @@ -0,0 +1,169 @@ | |||
| using System; | |||
| using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public partial interface ILayersApi | |||
| { | |||
| public IPreprocessing preprocessing { get; } | |||
| public ILayer Add(); | |||
| public ILayer AveragePooling2D(Shape pool_size = null, | |||
| Shape strides = null, | |||
| string padding = "valid", | |||
| string data_format = null); | |||
| public ILayer BatchNormalization(int axis = -1, | |||
| float momentum = 0.99f, | |||
| float epsilon = 0.001f, | |||
| bool center = true, | |||
| bool scale = true, | |||
| IInitializer beta_initializer = null, | |||
| IInitializer gamma_initializer = null, | |||
| IInitializer moving_mean_initializer = null, | |||
| IInitializer moving_variance_initializer = null, | |||
| bool trainable = true, | |||
| string name = null, | |||
| bool renorm = false, | |||
| float renorm_momentum = 0.99f); | |||
| public ILayer Conv1D(int filters, | |||
| Shape kernel_size, | |||
| int strides = 1, | |||
| string padding = "valid", | |||
| string data_format = "channels_last", | |||
| int dilation_rate = 1, | |||
| int groups = 1, | |||
| string activation = null, | |||
| bool use_bias = true, | |||
| string kernel_initializer = "glorot_uniform", | |||
| string bias_initializer = "zeros"); | |||
| public ILayer Conv2D(int filters, | |||
| Shape kernel_size = null, | |||
| Shape strides = null, | |||
| string padding = "valid", | |||
| string data_format = null, | |||
| Shape dilation_rate = null, | |||
| int groups = 1, | |||
| Activation activation = null, | |||
| bool use_bias = true, | |||
| IInitializer kernel_initializer = null, | |||
| IInitializer bias_initializer = null, | |||
| IRegularizer kernel_regularizer = null, | |||
| IRegularizer bias_regularizer = null, | |||
| IRegularizer activity_regularizer = null); | |||
| public ILayer Conv2D(int filters, | |||
| Shape kernel_size = null, | |||
| Shape strides = null, | |||
| string padding = "valid", | |||
| string data_format = null, | |||
| Shape dilation_rate = null, | |||
| int groups = 1, | |||
| string activation = null, | |||
| bool use_bias = true, | |||
| string kernel_initializer = "glorot_uniform", | |||
| string bias_initializer = "zeros"); | |||
| public ILayer Dense(int units); | |||
| public ILayer Dense(int units, | |||
| string activation = null, | |||
| Shape input_shape = null); | |||
| public ILayer Dense(int units, | |||
| Activation activation = null, | |||
| IInitializer kernel_initializer = null, | |||
| bool use_bias = true, | |||
| IInitializer bias_initializer = null, | |||
| Shape input_shape = null); | |||
| public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null); | |||
| public ILayer Embedding(int input_dim, | |||
| int output_dim, | |||
| IInitializer embeddings_initializer = null, | |||
| bool mask_zero = false, | |||
| Shape input_shape = null, | |||
| int input_length = -1); | |||
| public ILayer EinsumDense(string equation, | |||
| Shape output_shape, | |||
| string bias_axes, | |||
| Activation activation = null, | |||
| IInitializer kernel_initializer = null, | |||
| IInitializer bias_initializer = null, | |||
| IRegularizer kernel_regularizer = null, | |||
| IRegularizer bias_regularizer = null, | |||
| IRegularizer activity_regularizer = null, | |||
| Action kernel_constraint = null, | |||
| Action bias_constraint = null); | |||
| public ILayer Flatten(string data_format = null); | |||
| public ILayer GlobalAveragePooling1D(string data_format = "channels_last"); | |||
| public ILayer GlobalAveragePooling2D(); | |||
| public ILayer GlobalAveragePooling2D(string data_format = "channels_last"); | |||
| public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); | |||
| public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); | |||
| public Tensors Input(Shape shape, | |||
| string name = null, | |||
| bool sparse = false, | |||
| bool ragged = false); | |||
| public ILayer InputLayer(Shape input_shape, | |||
| string name = null, | |||
| bool sparse = false, | |||
| bool ragged = false); | |||
| public ILayer LayerNormalization(Axis? axis, | |||
| float epsilon = 1e-3f, | |||
| bool center = true, | |||
| bool scale = true, | |||
| IInitializer beta_initializer = null, | |||
| IInitializer gamma_initializer = null); | |||
| public ILayer LeakyReLU(float alpha = 0.3f); | |||
| public ILayer LSTM(int units, | |||
| Activation activation = null, | |||
| Activation recurrent_activation = null, | |||
| bool use_bias = true, | |||
| IInitializer kernel_initializer = null, | |||
| IInitializer recurrent_initializer = null, | |||
| IInitializer bias_initializer = null, | |||
| bool unit_forget_bias = true, | |||
| float dropout = 0f, | |||
| float recurrent_dropout = 0f, | |||
| int implementation = 2, | |||
| bool return_sequences = false, | |||
| bool return_state = false, | |||
| bool go_backwards = false, | |||
| bool stateful = false, | |||
| bool time_major = false, | |||
| bool unroll = false); | |||
| public ILayer MaxPooling1D(int? pool_size = null, | |||
| int? strides = null, | |||
| string padding = "valid", | |||
| string data_format = null); | |||
| public ILayer MaxPooling2D(Shape pool_size = null, | |||
| Shape strides = null, | |||
| string padding = "valid", | |||
| string data_format = null); | |||
| public ILayer Permute(int[] dims); | |||
| public ILayer Rescaling(float scale, | |||
| float offset = 0, | |||
| Shape input_shape = null); | |||
| public ILayer SimpleRNN(int units, | |||
| string activation = "tanh", | |||
| string kernel_initializer = "glorot_uniform", | |||
| string recurrent_initializer = "orthogonal", | |||
| string bias_initializer = "zeros"); | |||
| public ILayer Subtract(); | |||
| } | |||
| } | |||
| @@ -20,11 +20,11 @@ namespace Tensorflow.NumPy | |||
| Marshal.Copy(y.BufferToArray(), 0, x.TensorDataPointer, (int)x.bytesize); | |||
| } | |||
| public NDArray rand(params int[] shape) | |||
| => throw new NotImplementedException(""); | |||
| public NDArray random(Shape size) | |||
| => uniform(low: 0, high: 1, size: size); | |||
| [AutoNumPy] | |||
| public NDArray randint(int low, int? high = null, Shape size = null, TF_DataType dtype = TF_DataType.TF_INT32) | |||
| public NDArray randint(int low, int? high = null, Shape? size = null, TF_DataType dtype = TF_DataType.TF_INT32) | |||
| { | |||
| if(high == null) | |||
| { | |||
| @@ -41,11 +41,11 @@ namespace Tensorflow.NumPy | |||
| => new NDArray(random_ops.random_normal(shape ?? Shape.Scalar)); | |||
| [AutoNumPy] | |||
| public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape size = null) | |||
| public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape? size = null) | |||
| => new NDArray(random_ops.random_normal(size ?? Shape.Scalar, mean: loc, stddev: scale)); | |||
| [AutoNumPy] | |||
| public NDArray uniform(float low = 0.0f, float high = 1.0f, Shape size = null) | |||
| public NDArray uniform(float low = 0.0f, float high = 1.0f, Shape? size = null) | |||
| => new NDArray(random_ops.random_uniform(size ?? Shape.Scalar, low, high)); | |||
| } | |||
| } | |||
| @@ -18,6 +18,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Util; | |||
| @@ -5,8 +5,8 @@ | |||
| <AssemblyName>Tensorflow.Binding</AssemblyName> | |||
| <RootNamespace>Tensorflow</RootNamespace> | |||
| <TargetTensorFlow>2.2.0</TargetTensorFlow> | |||
| <Version>0.70.2</Version> | |||
| <LangVersion>9.0</LangVersion> | |||
| <Version>0.100.0</Version> | |||
| <LangVersion>10.0</LangVersion> | |||
| <Nullable>enable</Nullable> | |||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||
| <Company>SciSharp STACK</Company> | |||
| @@ -20,9 +20,9 @@ | |||
| <Description>Google's TensorFlow full binding in .NET Standard. | |||
| Building, training and infering deep learning models. | |||
| https://tensorflownet.readthedocs.io</Description> | |||
| <AssemblyVersion>0.70.1.0</AssemblyVersion> | |||
| <AssemblyVersion>0.100.0.0</AssemblyVersion> | |||
| <PackageReleaseNotes> | |||
| tf.net 0.70.x and above are based on tensorflow native 2.7.0 | |||
| tf.net 0.100.x and above are based on tensorflow native 2.10.0 | |||
| * Eager Mode is added finally. | |||
| * tf.keras is partially working. | |||
| @@ -35,14 +35,17 @@ https://tensorflownet.readthedocs.io</Description> | |||
| tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library. | |||
| tf.net 0.6x.x aligns with TensorFlow v2.6.x native library. | |||
| tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.</PackageReleaseNotes> | |||
| <FileVersion>0.70.1.0</FileVersion> | |||
| tf.net 0.7x.x aligns with TensorFlow v2.7.x native library. | |||
| tf.net 0.10x.x aligns with TensorFlow v2.10.x native library. | |||
| </PackageReleaseNotes> | |||
| <FileVersion>0.100.0.0</FileVersion> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
| <SignAssembly>true</SignAssembly> | |||
| <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| <PackageId>TensorFlow.NET</PackageId> | |||
| <Configurations>Debug;Release;GPU</Configurations> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
| @@ -51,6 +54,12 @@ https://tensorflownet.readthedocs.io</Description> | |||
| <PlatformTarget>AnyCPU</PlatformTarget> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|AnyCPU'"> | |||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
| <DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE_1</DefineConstants> | |||
| <PlatformTarget>AnyCPU</PlatformTarget> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | |||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
| <DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants> | |||
| @@ -58,6 +67,13 @@ https://tensorflownet.readthedocs.io</Description> | |||
| <DocumentationFile>TensorFlow.NET.xml</DocumentationFile> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|x64'"> | |||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
| <DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants> | |||
| <PlatformTarget>x64</PlatformTarget> | |||
| <DocumentationFile>TensorFlow.NET.xml</DocumentationFile> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
| </PropertyGroup> | |||
| @@ -20,6 +20,7 @@ using System.Threading; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Keras; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -51,6 +52,8 @@ namespace Tensorflow | |||
| ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner()); | |||
| public IEagerRunner Runner => _runner.Value; | |||
| public IKerasApi keras { get; set; } | |||
| public tensorflow() | |||
| { | |||
| Logger = new LoggerConfiguration() | |||
| @@ -2,6 +2,9 @@ | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Deprecated, will use tf.keras | |||
| /// </summary> | |||
| public static class KerasApi | |||
| { | |||
| public static KerasInterface keras { get; } = new KerasInterface(); | |||
| @@ -10,18 +10,17 @@ using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Metrics; | |||
| using Tensorflow.Keras.Models; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Keras.Utils; | |||
| using System.Threading; | |||
| namespace Tensorflow.Keras | |||
| { | |||
| public class KerasInterface | |||
| public class KerasInterface : IKerasApi | |||
| { | |||
| public KerasDataset datasets { get; } = new KerasDataset(); | |||
| public Initializers initializers { get; } = new Initializers(); | |||
| public Regularizers regularizers { get; } = new Regularizers(); | |||
| public LayersApi layers { get; } = new LayersApi(); | |||
| public ILayersApi layers { get; } = new LayersApi(); | |||
| public LossesApi losses { get; } = new LossesApi(); | |||
| public Activations activations { get; } = new Activations(); | |||
| public Preprocessing preprocessing { get; } = new Preprocessing(); | |||
| @@ -7,16 +7,16 @@ using static Tensorflow.KerasApi; | |||
| namespace Tensorflow.Keras.Layers { | |||
| public partial class LayersApi { | |||
| public ELU ELU ( float alpha = 0.1f ) | |||
| public ILayer ELU ( float alpha = 0.1f ) | |||
| => new ELU(new ELUArgs { Alpha = alpha }); | |||
| public SELU SELU () | |||
| public ILayer SELU () | |||
| => new SELU(new LayerArgs { }); | |||
| public Softmax Softmax ( Axis axis ) => new Softmax(new SoftmaxArgs { axis = axis }); | |||
| public Softplus Softplus () => new Softplus(new LayerArgs { }); | |||
| public HardSigmoid HardSigmoid () => new HardSigmoid(new LayerArgs { }); | |||
| public Softsign Softsign () => new Softsign(new LayerArgs { }); | |||
| public Swish Swish () => new Swish(new LayerArgs { }); | |||
| public Tanh Tanh () => new Tanh(new LayerArgs { }); | |||
| public Exponential Exponential () => new Exponential(new LayerArgs { }); | |||
| public ILayer Softmax ( Axis axis ) => new Softmax(new SoftmaxArgs { axis = axis }); | |||
| public ILayer Softplus () => new Softplus(new LayerArgs { }); | |||
| public ILayer HardSigmoid () => new HardSigmoid(new LayerArgs { }); | |||
| public ILayer Softsign () => new Softsign(new LayerArgs { }); | |||
| public ILayer Swish () => new Swish(new LayerArgs { }); | |||
| public ILayer Tanh () => new Tanh(new LayerArgs { }); | |||
| public ILayer Exponential () => new Exponential(new LayerArgs { }); | |||
| } | |||
| } | |||
| @@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Layers | |||
| { | |||
| public partial class LayersApi | |||
| { | |||
| public Attention Attention(bool use_scale = false, | |||
| public ILayer Attention(bool use_scale = false, | |||
| string score_mode = "dot", | |||
| bool causal = false, | |||
| float dropout = 0f) => | |||
| @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers | |||
| causal = causal, | |||
| dropout = dropout | |||
| }); | |||
| public MultiHeadAttention MultiHeadAttention(int num_heads, | |||
| public ILayer MultiHeadAttention(int num_heads, | |||
| int key_dim, | |||
| int? value_dim = null, | |||
| float dropout = 0f, | |||
| @@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Layers { | |||
| /// Cropping layer for 1D input | |||
| /// </summary> | |||
| /// <param name="cropping">cropping size</param> | |||
| public Cropping1D Cropping1D ( NDArray cropping ) | |||
| public ILayer Cropping1D ( NDArray cropping ) | |||
| => new Cropping1D(new CroppingArgs { | |||
| cropping = cropping | |||
| }); | |||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers { | |||
| /// <summary> | |||
| /// Cropping layer for 2D input <br/> | |||
| /// </summary> | |||
| public Cropping2D Cropping2D ( NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last ) | |||
| public ILayer Cropping2D ( NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last ) | |||
| => new Cropping2D(new Cropping2DArgs { | |||
| cropping = cropping, | |||
| data_format = data_format | |||
| @@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Layers { | |||
| /// <summary> | |||
| /// Cropping layer for 3D input <br/> | |||
| /// </summary> | |||
| public Cropping3D Cropping3D ( NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last ) | |||
| public ILayer Cropping3D ( NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last ) | |||
| => new Cropping3D(new Cropping3DArgs { | |||
| cropping = cropping, | |||
| data_format = data_format | |||
| @@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// </summary> | |||
| /// <param name="axis">Axis along which to concatenate.</param> | |||
| /// <returns></returns> | |||
| public Concatenate Concatenate(int axis = -1) | |||
| public ILayer Concatenate(int axis = -1) | |||
| => new Concatenate(new MergeArgs | |||
| { | |||
| Axis = axis | |||
| @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Layers { | |||
| /// </summary> | |||
| /// <param name="padding"></param> | |||
| /// <returns></returns> | |||
| public ZeroPadding2D ZeroPadding2D ( NDArray padding ) | |||
| public ILayer ZeroPadding2D ( NDArray padding ) | |||
| => new ZeroPadding2D(new ZeroPadding2DArgs { | |||
| Padding = padding | |||
| }); | |||
| @@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers { | |||
| /// <param name="data_format"></param> | |||
| /// <param name="interpolation"></param> | |||
| /// <returns></returns> | |||
| public UpSampling2D UpSampling2D ( Shape size = null, | |||
| public ILayer UpSampling2D ( Shape size = null, | |||
| string data_format = null, | |||
| string interpolation = "nearest" ) | |||
| => new UpSampling2D(new UpSampling2DArgs { | |||
| @@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Layers { | |||
| /// <summary> | |||
| /// Permutes the dimensions of the input according to a given pattern. | |||
| /// </summary> | |||
| public Permute Permute ( int[] dims ) | |||
| public ILayer Permute ( int[] dims ) | |||
| => new Permute(new PermuteArgs { | |||
| dims = dims | |||
| }); | |||
| @@ -44,12 +44,12 @@ namespace Tensorflow.Keras.Layers { | |||
| /// </summary> | |||
| /// <param name="target_shape"></param> | |||
| /// <returns></returns> | |||
| public Reshape Reshape ( Shape target_shape ) | |||
| => new Reshape(new ReshapeArgs { | |||
| TargetShape = target_shape | |||
| }); | |||
| public ILayer Reshape ( Shape target_shape ) | |||
| => new Reshape(new ReshapeArgs { | |||
| TargetShape = target_shape | |||
| }); | |||
| public Reshape Reshape ( object[] target_shape ) | |||
| public ILayer Reshape ( object[] target_shape ) | |||
| => new Reshape(new ReshapeArgs { | |||
| TargetShapeObjects = target_shape | |||
| }); | |||
| @@ -1,16 +1,18 @@ | |||
| using System; | |||
| using Tensorflow.NumPy; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Lstm; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers.Lstm; | |||
| using Tensorflow.Keras.Layers.Rnn; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public partial class LayersApi | |||
| public partial class LayersApi : ILayersApi | |||
| { | |||
| public Preprocessing preprocessing { get; } = new Preprocessing(); | |||
| public IPreprocessing preprocessing { get; } = new Preprocessing(); | |||
| /// <summary> | |||
| /// Layer that normalizes its inputs. | |||
| @@ -38,7 +40,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// Note that momentum is still applied to get the means and variances for inference. | |||
| /// </param> | |||
| /// <returns>Tensor of the same shape as input.</returns> | |||
| public BatchNormalization BatchNormalization(int axis = -1, | |||
| public ILayer BatchNormalization(int axis = -1, | |||
| float momentum = 0.99f, | |||
| float epsilon = 0.001f, | |||
| bool center = true, | |||
| @@ -84,7 +86,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// <param name="kernel_initializer">Initializer for the kernel weights matrix (see keras.initializers).</param> | |||
| /// <param name="bias_initializer">Initializer for the bias vector (see keras.initializers).</param> | |||
| /// <returns>A tensor of rank 3 representing activation(conv1d(inputs, kernel) + bias).</returns> | |||
| public Conv1D Conv1D(int filters, | |||
| public ILayer Conv1D(int filters, | |||
| Shape kernel_size, | |||
| int strides = 1, | |||
| string padding = "valid", | |||
| @@ -131,7 +133,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// <param name="bias_regularizer">Regularizer function applied to the bias vector (see keras.regularizers).</param> | |||
| /// <param name="activity_regularizer">Regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param> | |||
| /// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns> | |||
| public Conv2D Conv2D(int filters, | |||
| public ILayer Conv2D(int filters, | |||
| Shape kernel_size = null, | |||
| Shape strides = null, | |||
| string padding = "valid", | |||
| @@ -184,7 +186,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// <param name="bias_regularizer">The name of the regularizer function applied to the bias vector (see keras.regularizers).</param> | |||
| /// <param name="activity_regularizer">The name of the regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param> | |||
| /// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns> | |||
| public Conv2D Conv2D(int filters, | |||
| public ILayer Conv2D(int filters, | |||
| Shape kernel_size = null, | |||
| Shape strides = null, | |||
| string padding = "valid", | |||
| @@ -228,7 +230,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// <param name="bias_regularizer">The name of the regularizer function applied to the bias vector (see keras.regularizers).</param> | |||
| /// <param name="activity_regularizer">The name of the regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param> | |||
| /// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns> | |||
| public Conv2DTranspose Conv2DTranspose(int filters, | |||
| public ILayer Conv2DTranspose(int filters, | |||
| Shape kernel_size = null, | |||
| Shape strides = null, | |||
| string output_padding = "valid", | |||
| @@ -270,7 +272,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// <param name="bias_initializer">Initializer for the bias vector.</param> | |||
| /// <param name="input_shape">N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim).</param> | |||
| /// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns> | |||
| public Dense Dense(int units, | |||
| public ILayer Dense(int units, | |||
| Activation activation = null, | |||
| IInitializer kernel_initializer = null, | |||
| bool use_bias = true, | |||
| @@ -294,7 +296,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// </summary> | |||
| /// <param name="units">Positive integer, dimensionality of the output space.</param> | |||
| /// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns> | |||
| public Dense Dense(int units) | |||
| public ILayer Dense(int units) | |||
| => new Dense(new DenseArgs | |||
| { | |||
| Units = units, | |||
| @@ -312,7 +314,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// <param name="activation">Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x).</param> | |||
| /// <param name="input_shape">N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim).</param> | |||
| /// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns> | |||
| public Dense Dense(int units, | |||
| public ILayer Dense(int units, | |||
| string activation = null, | |||
| Shape input_shape = null) | |||
| => new Dense(new DenseArgs | |||
| @@ -364,7 +366,7 @@ namespace Tensorflow.Keras.Layers | |||
| } | |||
| public EinsumDense EinsumDense(string equation, | |||
| public ILayer EinsumDense(string equation, | |||
| Shape output_shape, | |||
| string bias_axes, | |||
| Activation activation = null, | |||
| @@ -402,7 +404,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// </param> | |||
| /// <param name="seed">An integer to use as random seed.</param> | |||
| /// <returns></returns> | |||
| public Dropout Dropout(float rate, Shape noise_shape = null, int? seed = null) | |||
| public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null) | |||
| => new Dropout(new DropoutArgs | |||
| { | |||
| Rate = rate, | |||
| @@ -421,7 +423,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// <param name="embeddings_initializer">Initializer for the embeddings matrix (see keras.initializers).</param> | |||
| /// <param name="mask_zero"></param> | |||
| /// <returns></returns> | |||
| public Embedding Embedding(int input_dim, | |||
| public ILayer Embedding(int input_dim, | |||
| int output_dim, | |||
| IInitializer embeddings_initializer = null, | |||
| bool mask_zero = false, | |||
| @@ -446,7 +448,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// If you never set it, then it will be "channels_last". | |||
| /// </param> | |||
| /// <returns></returns> | |||
| public Flatten Flatten(string data_format = null) | |||
| public ILayer Flatten(string data_format = null) | |||
| => new Flatten(new FlattenArgs | |||
| { | |||
| DataFormat = data_format | |||
| @@ -482,7 +484,7 @@ namespace Tensorflow.Keras.Layers | |||
| return input_layer.InboundNodes[0].Outputs; | |||
| } | |||
| public InputLayer InputLayer(Shape input_shape, | |||
| public ILayer InputLayer(Shape input_shape, | |||
| string name = null, | |||
| bool sparse = false, | |||
| bool ragged = false) | |||
| @@ -502,7 +504,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// <param name="padding"></param> | |||
| /// <param name="data_format"></param> | |||
| /// <returns></returns> | |||
| public AveragePooling2D AveragePooling2D(Shape pool_size = null, | |||
| public ILayer AveragePooling2D(Shape pool_size = null, | |||
| Shape strides = null, | |||
| string padding = "valid", | |||
| string data_format = null) | |||
| @@ -527,7 +529,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps). | |||
| /// </param> | |||
| /// <returns></returns> | |||
| public MaxPooling1D MaxPooling1D(int? pool_size = null, | |||
| public ILayer MaxPooling1D(int? pool_size = null, | |||
| int? strides = null, | |||
| string padding = "valid", | |||
| string data_format = null) | |||
| @@ -564,7 +566,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json. | |||
| /// If you never set it, then it will be "channels_last"</param> | |||
| /// <returns></returns> | |||
| public MaxPooling2D MaxPooling2D(Shape pool_size = null, | |||
| public ILayer MaxPooling2D(Shape pool_size = null, | |||
| Shape strides = null, | |||
| string padding = "valid", | |||
| string data_format = null) | |||
| @@ -618,7 +620,7 @@ namespace Tensorflow.Keras.Layers | |||
| return layer.Apply(inputs); | |||
| } | |||
| public Layer LayerNormalization(Axis? axis, | |||
| public ILayer LayerNormalization(Axis? axis, | |||
| float epsilon = 1e-3f, | |||
| bool center = true, | |||
| bool scale = true, | |||
| @@ -638,45 +640,30 @@ namespace Tensorflow.Keras.Layers | |||
| /// </summary> | |||
| /// <param name="alpha">Negative slope coefficient.</param> | |||
| /// <returns></returns> | |||
| public Layer LeakyReLU(float alpha = 0.3f) | |||
| public ILayer LeakyReLU(float alpha = 0.3f) | |||
| => new LeakyReLu(new LeakyReLuArgs | |||
| { | |||
| Alpha = alpha | |||
| }); | |||
| /// <summary> | |||
| /// Fully-connected RNN where the output is to be fed back to input. | |||
| /// </summary> | |||
| /// <param name="units">Positive integer, dimensionality of the output space.</param> | |||
| /// <returns></returns> | |||
| public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh"); | |||
| /// <summary> | |||
| /// Fully-connected RNN where the output is to be fed back to input. | |||
| /// </summary> | |||
| /// <param name="units">Positive integer, dimensionality of the output space.</param> | |||
| /// <param name="activation">Activation function to use. If you pass null, no activation is applied (ie. "linear" activation: a(x) = x).</param> | |||
| /// <returns></returns> | |||
| public Layer SimpleRNN(int units, | |||
| Activation activation = null) | |||
| => new SimpleRNN(new SimpleRNNArgs | |||
| { | |||
| Units = units, | |||
| Activation = activation | |||
| }); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="units">Positive integer, dimensionality of the output space.</param> | |||
| /// <param name="activation">The name of the activation function to use. Default: hyperbolic tangent (tanh)..</param> | |||
| /// <returns></returns> | |||
| public Layer SimpleRNN(int units, | |||
| string activation = "tanh") | |||
| public ILayer SimpleRNN(int units, | |||
| string activation = "tanh", | |||
| string kernel_initializer = "glorot_uniform", | |||
| string recurrent_initializer = "orthogonal", | |||
| string bias_initializer = "zeros") | |||
| => new SimpleRNN(new SimpleRNNArgs | |||
| { | |||
| Units = units, | |||
| Activation = GetActivationByName(activation) | |||
| Activation = GetActivationByName(activation), | |||
| KernelInitializer = GetInitializerByName(kernel_initializer), | |||
| RecurrentInitializer= GetInitializerByName(recurrent_initializer), | |||
| BiasInitializer= GetInitializerByName(bias_initializer) | |||
| }); | |||
| /// <summary> | |||
| @@ -706,7 +693,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. | |||
| /// </param> | |||
| /// <returns></returns> | |||
| public Layer LSTM(int units, | |||
| public ILayer LSTM(int units, | |||
| Activation activation = null, | |||
| Activation recurrent_activation = null, | |||
| bool use_bias = true, | |||
| @@ -749,7 +736,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// <param name="offset"></param> | |||
| /// <param name="input_shape"></param> | |||
| /// <returns></returns> | |||
| public Rescaling Rescaling(float scale, | |||
| public ILayer Rescaling(float scale, | |||
| float offset = 0, | |||
| Shape input_shape = null) | |||
| => new Rescaling(new RescalingArgs | |||
| @@ -763,21 +750,21 @@ namespace Tensorflow.Keras.Layers | |||
| /// | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public Add Add() | |||
| public ILayer Add() | |||
| => new Add(new MergeArgs { }); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public Subtract Subtract() | |||
| public ILayer Subtract() | |||
| => new Subtract(new MergeArgs { }); | |||
| /// <summary> | |||
| /// Global max pooling operation for spatial data. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public GlobalAveragePooling2D GlobalAveragePooling2D() | |||
| public ILayer GlobalAveragePooling2D() | |||
| => new GlobalAveragePooling2D(new Pooling2DArgs { }); | |||
| /// <summary> | |||
| @@ -787,7 +774,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps). | |||
| /// </param> | |||
| /// <returns></returns> | |||
| public GlobalAveragePooling1D GlobalAveragePooling1D(string data_format = "channels_last") | |||
| public ILayer GlobalAveragePooling1D(string data_format = "channels_last") | |||
| => new GlobalAveragePooling1D(new Pooling1DArgs { DataFormat = data_format }); | |||
| /// <summary> | |||
| @@ -796,7 +783,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// <param name="data_format">A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. | |||
| /// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width).</param> | |||
| /// <returns></returns> | |||
| public GlobalAveragePooling2D GlobalAveragePooling2D(string data_format = "channels_last") | |||
| public ILayer GlobalAveragePooling2D(string data_format = "channels_last") | |||
| => new GlobalAveragePooling2D(new Pooling2DArgs { DataFormat = data_format }); | |||
| /// <summary> | |||
| @@ -807,7 +794,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps). | |||
| /// </param> | |||
| /// <returns></returns> | |||
| public GlobalMaxPooling1D GlobalMaxPooling1D(string data_format = "channels_last") | |||
| public ILayer GlobalMaxPooling1D(string data_format = "channels_last") | |||
| => new GlobalMaxPooling1D(new Pooling1DArgs { DataFormat = data_format }); | |||
| /// <summary> | |||
| @@ -816,7 +803,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// <param name="data_format">A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. | |||
| /// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width).</param> | |||
| /// <returns></returns> | |||
| public GlobalMaxPooling2D GlobalMaxPooling2D(string data_format = "channels_last") | |||
| public ILayer GlobalMaxPooling2D(string data_format = "channels_last") | |||
| => new GlobalMaxPooling2D(new Pooling2DArgs { DataFormat = data_format }); | |||
| @@ -848,6 +835,7 @@ namespace Tensorflow.Keras.Layers | |||
| "glorot_uniform" => tf.glorot_uniform_initializer, | |||
| "zeros" => tf.zeros_initializer, | |||
| "ones" => tf.ones_initializer, | |||
| "orthogonal" => tf.orthogonal_initializer, | |||
| _ => tf.glorot_uniform_initializer | |||
| }; | |||
| } | |||
| @@ -1,8 +1,9 @@ | |||
| using System.Linq; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Lstm; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers.Rnn; | |||
| namespace Tensorflow.Keras.Layers | |||
| namespace Tensorflow.Keras.Layers.Lstm | |||
| { | |||
| /// <summary> | |||
| /// Long Short-Term Memory layer - Hochreiter 1997. | |||
| @@ -1,7 +1,7 @@ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Lstm; | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow.Keras.Layers | |||
| namespace Tensorflow.Keras.Layers.Lstm | |||
| { | |||
| public class LSTMCell : Layer | |||
| { | |||
| @@ -1,10 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers.Lstm; | |||
| // from tensorflow.python.distribute import distribution_strategy_context as ds_context; | |||
| namespace Tensorflow.Keras.Layers | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| public class RNN : Layer | |||
| { | |||
| @@ -14,6 +16,8 @@ namespace Tensorflow.Keras.Layers | |||
| private object _states = null; | |||
| private object constants_spec = null; | |||
| private int _num_constants = 0; | |||
| protected IVariableV1 kernel; | |||
| protected IVariableV1 bias; | |||
| public RNN(RNNArgs args) : base(PreConstruct(args)) | |||
| { | |||
| @@ -0,0 +1,31 @@ | |||
| using System.Data; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Operations.Activation; | |||
| using static HDF.PInvoke.H5Z; | |||
| using static Tensorflow.ApiDef.Types; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| public class SimpleRNN : RNN | |||
| { | |||
| SimpleRNNArgs args; | |||
| SimpleRNNCell cell; | |||
| public SimpleRNN(SimpleRNNArgs args) : base(args) | |||
| { | |||
| this.args = args; | |||
| } | |||
| protected override void build(Tensors inputs) | |||
| { | |||
| var input_shape = inputs.shape; | |||
| var input_dim = input_shape[-1]; | |||
| kernel = add_weight("kernel", (input_shape[-1], args.Units), | |||
| initializer: args.KernelInitializer | |||
| //regularizer = self.kernel_regularizer, | |||
| //constraint = self.kernel_constraint, | |||
| //caching_device = default_caching_device, | |||
| ); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,21 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| public class SimpleRNNCell : Layer | |||
| { | |||
| public SimpleRNNCell(SimpleRNNArgs args) : base(args) | |||
| { | |||
| } | |||
| protected override void build(Tensors inputs) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -2,9 +2,10 @@ | |||
| using System.Collections.Generic; | |||
| using System.ComponentModel; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow.Keras.Layers | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell | |||
| { | |||
| @@ -1,14 +0,0 @@ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public class SimpleRNN : RNN | |||
| { | |||
| public SimpleRNN(RNNArgs args) : base(args) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -15,7 +15,7 @@ namespace Tensorflow.Keras | |||
| /// <param name="width"></param> | |||
| /// <param name="interpolation"></param> | |||
| /// <returns></returns> | |||
| public Resizing Resizing(int height, int width, string interpolation = "bilinear") | |||
| public ILayer Resizing(int height, int width, string interpolation = "bilinear") | |||
| => new Resizing(new ResizingArgs | |||
| { | |||
| Height = height, | |||
| @@ -5,7 +5,7 @@ using Tensorflow.Keras.Preprocessings; | |||
| namespace Tensorflow.Keras | |||
| { | |||
| public partial class Preprocessing | |||
| public partial class Preprocessing : IPreprocessing | |||
| { | |||
| public Sequence sequence => new Sequence(); | |||
| public DatasetUtils dataset_utils => new DatasetUtils(); | |||
| @@ -14,7 +14,7 @@ namespace Tensorflow.Keras | |||
| private static TextApi _text = new TextApi(); | |||
| public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null, | |||
| public ILayer TextVectorization(Func<Tensor, Tensor> standardize = null, | |||
| string split = "whitespace", | |||
| int max_tokens = -1, | |||
| string output_mode = "int", | |||
| @@ -3,11 +3,11 @@ | |||
| <PropertyGroup> | |||
| <TargetFramework>netstandard2.0</TargetFramework> | |||
| <AssemblyName>Tensorflow.Keras</AssemblyName> | |||
| <LangVersion>9.0</LangVersion> | |||
| <LangVersion>10.0</LangVersion> | |||
| <Nullable>enable</Nullable> | |||
| <RootNamespace>Tensorflow.Keras</RootNamespace> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| <Version>0.7.0</Version> | |||
| <Version>0.10.0</Version> | |||
| <Authors>Haiping Chen</Authors> | |||
| <Product>Keras for .NET</Product> | |||
| <Copyright>Apache 2.0, Haiping Chen 2021</Copyright> | |||
| @@ -37,9 +37,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
| <RepositoryType>Git</RepositoryType> | |||
| <SignAssembly>true</SignAssembly> | |||
| <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
| <AssemblyVersion>0.7.0.0</AssemblyVersion> | |||
| <FileVersion>0.7.0.0</FileVersion> | |||
| <AssemblyVersion>0.10.0.0</AssemblyVersion> | |||
| <FileVersion>0.10.0.0</FileVersion> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| <Configurations>Debug;Release;GPU</Configurations> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
| @@ -47,6 +48,11 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
| <AllowUnsafeBlocks>false</AllowUnsafeBlocks> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|AnyCPU'"> | |||
| <DefineConstants>DEBUG;TRACE</DefineConstants> | |||
| <AllowUnsafeBlocks>false</AllowUnsafeBlocks> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||
| <AllowUnsafeBlocks>false</AllowUnsafeBlocks> | |||
| </PropertyGroup> | |||
| @@ -55,6 +61,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
| <DocumentationFile>Tensorflow.Keras.xml</DocumentationFile> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|x64'"> | |||
| <DocumentationFile>Tensorflow.Keras.xml</DocumentationFile> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> | |||
| <DefineConstants /> | |||
| </PropertyGroup> | |||
| @@ -134,7 +134,7 @@ namespace Tensorflow.Keras | |||
| /// <param name="data_format"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor max_pooling2d(Tensor inputs, | |||
| public Tensor MaxPooling2D(Tensor inputs, | |||
| int[] pool_size, | |||
| int[] strides, | |||
| string padding = "valid", | |||
| @@ -0,0 +1,16 @@ | |||
| { | |||
| // Use IntelliSense to learn about possible attributes. | |||
| // Hover to view descriptions of existing attributes. | |||
| // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 | |||
| "version": "0.2.0", | |||
| "configurations": [ | |||
| { | |||
| "name": "Python: Current File", | |||
| "type": "python", | |||
| "request": "launch", | |||
| "program": "${file}", | |||
| "console": "integratedTerminal", | |||
| "justMyCode": false | |||
| } | |||
| ] | |||
| } | |||
| @@ -0,0 +1,15 @@ | |||
| import numpy as np | |||
| import tensorflow as tf | |||
| # tf.experimental.numpy | |||
| inputs = np.random.random([32, 10, 8]).astype(np.float32) | |||
| simple_rnn = tf.keras.layers.SimpleRNN(4) | |||
| output = simple_rnn(inputs) # The output has shape `[32, 4]`. | |||
| simple_rnn = tf.keras.layers.SimpleRNN( | |||
| 4, return_sequences=True, return_state=True) | |||
| # whole_sequence_output has shape `[32, 10, 4]`. | |||
| # final_state has shape `[32, 4]`. | |||
| whole_sequence_output, final_state = simple_rnn(inputs) | |||
| @@ -83,7 +83,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||
| { 2.5f, 2.6f, 2.7f, 2.8f }, | |||
| { 3.5f, 3.6f, 3.7f, 3.8f } | |||
| } }, dtype: np.float32); | |||
| var attention_layer = keras.layers.Attention(); | |||
| var attention_layer = (Attention)keras.layers.Attention(); | |||
| //attention_layer.build(((1, 2, 4), (1, 3, 4))); | |||
| var actual = attention_layer._calculate_scores(query: q, key: k); | |||
| // Expected tensor of shape [1, 2, 3]. | |||
| @@ -116,7 +116,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||
| { 2.5f, 2.6f, 2.7f, 2.8f }, | |||
| { 3.5f, 3.6f, 3.7f, 3.8f } | |||
| } }, dtype: np.float32); | |||
| var attention_layer = keras.layers.Attention(score_mode: "concat"); | |||
| var attention_layer = (Attention)keras.layers.Attention(score_mode: "concat"); | |||
| //attention_layer.concat_score_weight = 1; | |||
| attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() { | |||
| Name = "concat_score_weight", | |||
| @@ -148,10 +148,9 @@ namespace TensorFlowNET.Keras.UnitTest | |||
| } | |||
| [TestMethod] | |||
| [Ignore] | |||
| public void SimpleRNN() | |||
| { | |||
| var inputs = np.random.rand(32, 10, 8).astype(np.float32); | |||
| var inputs = np.random.random((32, 10, 8)).astype(np.float32); | |||
| var simple_rnn = keras.layers.SimpleRNN(4); | |||
| var output = simple_rnn.Apply(inputs); | |||
| Assert.AreEqual((32, 4), output.shape); | |||
| @@ -4,7 +4,7 @@ | |||
| <TargetFramework>net6.0</TargetFramework> | |||
| <IsPackable>false</IsPackable> | |||
| <LangVersion>11.0</LangVersion> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| </PropertyGroup> | |||
| @@ -11,7 +11,7 @@ | |||
| <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
| <LangVersion>9.0</LangVersion> | |||
| <LangVersion>11.0</LangVersion> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| </PropertyGroup> | |||