diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 4c7ba775..025d3d57 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -97,6 +97,9 @@ namespace Tensorflow
throw new NotImplementedException("");
}
+ public static Tensor elu(Tensor features, string name = null)
+ => gen_nn_ops.elu(features, name: name);
+
public static (Tensor, Tensor) moments(Tensor x,
int[] axes,
string name = null,
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 7e281d46..9244ccff 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -139,6 +139,27 @@ namespace Tensorflow.Operations
});
return _op.outputs[0];
+ }
+
+ ///
+ /// Computes exponential linear: exp(features) - 1 if < 0, features otherwise.
+ ///
+ ///
+ ///
+ ///
+ /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Elu'.
+ ///
+ ///
+ /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result.
+ ///
+ ///
+ /// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
+ /// ](http://arxiv.org/abs/1511.07289)
+ ///
+ public static Tensor elu(Tensor features, string name = "Elu")
+ {
+ var op = _op_def_lib._apply_op_helper("Elu", name: name, args: new { features });
+ return op.output;
}
public static Tensor[] _fused_batch_norm(Tensor x,
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
index 3902d5d0..fab83d8a 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
@@ -199,7 +199,7 @@ namespace TensorFlowNET.Examples.ImageProcess
RefVariable layer_biases = null;
with(tf.name_scope("biases"), delegate
{
- layer_biases = tf.Variable(tf.zeros((class_count)), name: "final_biases");
+ layer_biases = tf.Variable(tf.zeros(class_count), name: "final_biases");
variable_summaries(layer_biases);
});