diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index f6fa380c..c9294653 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -52,5 +52,13 @@ namespace Tensorflow stddev: stddev, seed: seed, dtype: dtype); + + public IInitializer random_normal_initializer(float mean = 0.0f, + float stddev = 1.0f, + int? seed = null, + TF_DataType dtype = TF_DataType.DtInvalid) => new RandomNormal(mean: mean, + stddev: stddev, + seed: seed, + dtype: dtype); } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs new file mode 100644 index 00000000..5b1b5713 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -0,0 +1,51 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations.Initializers +{ + public class RandomNormal : IInitializer + { + private float mean; + private float stddev; + private int? seed; + private TF_DataType dtype; + + public RandomNormal(float mean = 0.0f, + float stddev = 1.0f, + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) + { + this.mean = mean; + this.stddev = stddev; + this.seed = seed; + this.dtype = dtype; + } + + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) + { + throw new NotImplementedException(); + } + + public object get_config() + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index 63e0fca1..b189bb83 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -116,6 +116,19 @@ namespace Tensorflow return _softmax(logits, gen_nn_ops.log_softmax, axis, name); } + public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) + { + return tf_with(ops.name_scope(name, "LeakyRelu", new { features, alpha }), scope => + { + name = scope; + features = ops.convert_to_tensor(features, name: "features"); + if (features.dtype.is_integer()) + features = math_ops.cast(features, dtypes.float32); + return gen_nn_ops.leaky_relu(features, alpha: alpha, name: name); + //return math_ops.maximum(alpha * features, features, name: name); + }); + } + /// /// Performs the max pooling on the input. ///