diff --git a/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs index 68816963..74f1fe3e 100644 --- a/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs +++ b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs @@ -35,7 +35,7 @@ namespace Tensorflow /// Python `str` prepended to names of ops created by this function. /// log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type `self.dtype`. - /* + public Tensor log_prob(Tensor value, string name = "log_prob") { return _call_log_prob(value, name); @@ -45,18 +45,39 @@ namespace Tensorflow { with(ops.name_scope(name, "moments", new { value }), scope => { - value = _convert_to_tensor(value, "value", _dtype); + try + { + return _log_prob(value); + } + catch (Exception e1) + { + try + { + return math_ops.log(_prob(value)); + } catch (Exception e2) + { + throw new NotImplementedException(); + } + } }); + return null; + } + private Tensor _log_prob(Tensor value) + { throw new NotImplementedException(); - } - private Tensor _convert_to_tensor(Tensor value, string name = null, TF_DataType preferred_dtype) + private Tensor _prob(Tensor value) { throw new NotImplementedException(); } - */ + + public TF_DataType dtype() + { + return this._dtype; + } + /// /// Constructs the `Distribution' diff --git a/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs index 6c77450a..e82b2ddd 100644 --- a/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs +++ b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using Tensorflow; @@ -80,7 +81,7 @@ namespace Tensorflow private Tensor _log_prob(Tensor x) { - return _log_unnormalized_prob(_z(x)); + return _log_unnormalized_prob(_z(x)) -_log_normalization(); } private Tensor _log_unnormalized_prob (Tensor x) @@ -92,5 +93,11 @@ namespace Tensorflow { return (x - this._loc) / this._scale; } + + private Tensor _log_normalization() + { + Tensor t = new Tensor(Math.Log(2.0 * Math.PI)); + return 0.5 * t + math_ops.log(scale()); + } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 3210c742..2f307f6d 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -48,6 +48,12 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Computes square of x element-wise. + /// + /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`. + /// A name for the operation (optional). + /// A `Tensor`. Has the same type as `x`. public static Tensor square(Tensor x, string name = null) { var _op = _op_def_lib._apply_op_helper("Square", name, args: new { x }); @@ -55,6 +61,19 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Computes natural logarithm of x element-wise. + /// + /// A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`. + /// name: A name for the operation (optional). + /// A `Tensor`. Has the same type as `x`. + public static Tensor log(Tensor x, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Log", name, args: new { x }); + + return _op.outputs[0]; + } + public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= "") { var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.py.cs index a114234a..0909b187 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.py.cs @@ -60,6 +60,11 @@ namespace Tensorflow return gen_math_ops.square(x, name); } + public static Tensor log(Tensor x, string name = null) + { + return gen_math_ops.log(x, name); + } + /// /// Helper function for reduction ops. ///