| @@ -0,0 +1,13 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public static partial class tf | |||||
| { | |||||
| public static Tensor exp(Tensor x, | |||||
| string name = null) => gen_math_ops.exp(x, name); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,15 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public static partial class tf | |||||
| { | |||||
| public static Tensor reduce_logsumexp(Tensor input_tensor, | |||||
| int[] axis = null, | |||||
| bool keepdims = false, | |||||
| string name = null) => math_ops.reduce_logsumexp(input_tensor, axis, keepdims, name); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public static partial class tf | |||||
| { | |||||
| public static Tensor reshape(Tensor tensor, | |||||
| Tensor shape, | |||||
| string name = null) => gen_array_ops.reshape(tensor, shape, name); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public static partial class tf | |||||
| { | |||||
| public static Tensor tile(Tensor input, | |||||
| Tensor multiples, | |||||
| string name = null) => gen_array_ops.tile(input, multiples, name); | |||||
| } | |||||
| } | |||||
| @@ -35,7 +35,7 @@ namespace Tensorflow | |||||
| /// <param name="name"> Python `str` prepended to names of ops created by this function.</param> | /// <param name="name"> Python `str` prepended to names of ops created by this function.</param> | ||||
| /// <returns>log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type `self.dtype`.</returns> | /// <returns>log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type `self.dtype`.</returns> | ||||
| /* | |||||
| public Tensor log_prob(Tensor value, string name = "log_prob") | public Tensor log_prob(Tensor value, string name = "log_prob") | ||||
| { | { | ||||
| return _call_log_prob(value, name); | return _call_log_prob(value, name); | ||||
| @@ -45,18 +45,39 @@ namespace Tensorflow | |||||
| { | { | ||||
| with(ops.name_scope(name, "moments", new { value }), scope => | with(ops.name_scope(name, "moments", new { value }), scope => | ||||
| { | { | ||||
| value = _convert_to_tensor(value, "value", _dtype); | |||||
| try | |||||
| { | |||||
| return _log_prob(value); | |||||
| } | |||||
| catch (Exception e1) | |||||
| { | |||||
| try | |||||
| { | |||||
| return math_ops.log(_prob(value)); | |||||
| } catch (Exception e2) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| }); | }); | ||||
| return null; | |||||
| } | |||||
| private Tensor _log_prob(Tensor value) | |||||
| { | |||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| private Tensor _convert_to_tensor(Tensor value, string name = null, TF_DataType preferred_dtype) | |||||
| private Tensor _prob(Tensor value) | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| */ | |||||
| public TF_DataType dtype() | |||||
| { | |||||
| return this._dtype; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Constructs the `Distribution' | /// Constructs the `Distribution' | ||||
| @@ -1,3 +1,4 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| @@ -80,7 +81,7 @@ namespace Tensorflow | |||||
| private Tensor _log_prob(Tensor x) | private Tensor _log_prob(Tensor x) | ||||
| { | { | ||||
| return _log_unnormalized_prob(_z(x)); | |||||
| return _log_unnormalized_prob(_z(x)) -_log_normalization(); | |||||
| } | } | ||||
| private Tensor _log_unnormalized_prob (Tensor x) | private Tensor _log_unnormalized_prob (Tensor x) | ||||
| @@ -92,5 +93,11 @@ namespace Tensorflow | |||||
| { | { | ||||
| return (x - this._loc) / this._scale; | return (x - this._loc) / this._scale; | ||||
| } | } | ||||
| private Tensor _log_normalization() | |||||
| { | |||||
| Tensor t = new Tensor(Math.Log(2.0 * Math.PI)); | |||||
| return 0.5 * t + math_ops.log(scale()); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -66,6 +66,11 @@ namespace Tensorflow | |||||
| public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | ||||
| => ones_like_impl(tensor, dtype, name, optimize); | => ones_like_impl(tensor, dtype, name, optimize); | ||||
| public static Tensor reshape(Tensor tensor, Tensor shape, string name = null) | |||||
| { | |||||
| return gen_array_ops.reshape(tensor, shape, null); | |||||
| } | |||||
| private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | ||||
| { | { | ||||
| return with(ops.name_scope(name, "ones_like", new { tensor }), scope => | return with(ops.name_scope(name, "ones_like", new { tensor }), scope => | ||||
| @@ -48,6 +48,58 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Computes square of x element-wise. | |||||
| /// </summary> | |||||
| /// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.</param> | |||||
| /// <param name="name"> A name for the operation (optional).</param> | |||||
| /// <returns> A `Tensor`. Has the same type as `x`.</returns> | |||||
| public static Tensor square(Tensor x, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Square", name, args: new { x }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns which elements of x are finite. | |||||
| /// </summary> | |||||
| /// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`.</param> | |||||
| /// <param name="name"> A name for the operation (optional).</param> | |||||
| /// <returns> A `Tensor` of type `bool`.</returns> | |||||
| public static Tensor is_finite(Tensor x, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("IsFinite", name, args: new { x }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| /// <summary> | |||||
| /// Computes exponential of x element-wise. \\(y = e^x\\). | |||||
| /// </summary> | |||||
| /// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`.</param> | |||||
| /// <param name="name"> A name for the operation (optional).</param> | |||||
| /// <returns> A `Tensor`. Has the same type as `x`.</returns> | |||||
| public static Tensor exp(Tensor x, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Exp", name, args: new { x }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| /// <summary> | |||||
| /// Computes natural logarithm of x element-wise. | |||||
| /// </summary> | |||||
| /// <param name="x"> A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`, `complex64`, `complex128`.</param> | |||||
| /// <param name="name"> name: A name for the operation (optional).</param> | |||||
| /// <returns> A `Tensor`. Has the same type as `x`.</returns> | |||||
| public static Tensor log(Tensor x, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Log", name, args: new { x }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= "") | public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= "") | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); | var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); | ||||
| @@ -134,6 +186,13 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor _max(Tensor input, int[] axis, bool keep_dims=false, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Pow", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Pow", name, args: new { x, y }); | ||||
| @@ -57,7 +57,12 @@ namespace Tensorflow | |||||
| public static Tensor square(Tensor x, string name = null) | public static Tensor square(Tensor x, string name = null) | ||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| return gen_math_ops.square(x, name); | |||||
| } | |||||
| public static Tensor log(Tensor x, string name = null) | |||||
| { | |||||
| return gen_math_ops.log(x, name); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -82,6 +87,51 @@ namespace Tensorflow | |||||
| return gen_data_flow_ops.dynamic_stitch(a1, a2); | return gen_data_flow_ops.dynamic_stitch(a1, a2); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Computes log(sum(exp(elements across dimensions of a tensor))). | |||||
| /// Reduces `input_tensor` along the dimensions given in `axis`. | |||||
| /// Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each | |||||
| /// entry in `axis`. If `keepdims` is true, the reduced dimensions | |||||
| /// are retained with length 1. | |||||
| /// If `axis` has no entries, all dimensions are reduced, and a | |||||
| /// tensor with a single element is returned. | |||||
| /// This function is more numerically stable than log(sum(exp(input))). It avoids | |||||
| /// overflows caused by taking the exp of large inputs and underflows caused by | |||||
| /// taking the log of small inputs. | |||||
| /// </summary> | |||||
| /// <param name="input_tensor"> The tensor to reduce. Should have numeric type.</param> | |||||
| /// <param name="axis"> The dimensions to reduce. If `None` (the default), reduces all | |||||
| /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.</param> | |||||
| /// <param name="keepdims"></param> | |||||
| /// <returns> The reduced tensor.</returns> | |||||
| public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) | |||||
| { | |||||
| with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope => | |||||
| { | |||||
| var raw_max = reduce_max(input_tensor, axis, true); | |||||
| var my_max = array_ops.stop_gradient(array_ops.where(gen_math_ops.is_finite(raw_max), raw_max, array_ops.zeros_like(raw_max))); | |||||
| var result = gen_math_ops.log( | |||||
| reduce_sum( | |||||
| gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)), | |||||
| new Tensor(axis), | |||||
| keepdims)); | |||||
| if (!keepdims) | |||||
| { | |||||
| my_max = array_ops.reshape(my_max, array_ops.shape(result)); | |||||
| } | |||||
| result = gen_math_ops.add(result, my_max); | |||||
| return _may_reduce_to_scalar(keepdims, axis, result); | |||||
| }); | |||||
| return null; | |||||
| } | |||||
| public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) | |||||
| { | |||||
| return _may_reduce_to_scalar(keepdims, axis, gen_math_ops._max(input_tensor, (int[])_ReductionDims(input_tensor, axis), keepdims, name)); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Casts a tensor to type `int32`. | /// Casts a tensor to type `int32`. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -12,6 +12,7 @@ namespace TensorFlowNET.Examples | |||||
| /// </summary> | /// </summary> | ||||
| public class NaiveBayesClassifier : Python, IExample | public class NaiveBayesClassifier : Python, IExample | ||||
| { | { | ||||
| public Normal dist { get; set; } | |||||
| public void Run() | public void Run() | ||||
| { | { | ||||
| np.array<float>(1.0f, 1.0f); | np.array<float>(1.0f, 1.0f); | ||||
| @@ -72,16 +73,34 @@ namespace TensorFlowNET.Examples | |||||
| // Create a 3x2 univariate normal distribution with the | // Create a 3x2 univariate normal distribution with the | ||||
| // Known mean and variance | // Known mean and variance | ||||
| var dist = tf.distributions.Normal(mean, tf.sqrt(variance)); | var dist = tf.distributions.Normal(mean, tf.sqrt(variance)); | ||||
| this.dist = dist; | |||||
| } | } | ||||
| public void predict (NDArray X) | |||||
| public Tensor predict (NDArray X) | |||||
| { | { | ||||
| // assert self.dist is not None | |||||
| // nb_classes, nb_features = map(int, self.dist.scale.shape) | |||||
| if (dist == null) | |||||
| { | |||||
| throw new ArgumentNullException("cant not find the model (normal distribution)!"); | |||||
| } | |||||
| int nb_classes = (int) dist.scale().shape[0]; | |||||
| int nb_features = (int)dist.scale().shape[1]; | |||||
| // Conditional probabilities log P(x|c) with shape | |||||
| // (nb_samples, nb_classes) | |||||
| Tensor tile = tf.tile(new Tensor(X), new Tensor(new int[] { -1, nb_classes, nb_features })); | |||||
| Tensor r = tf.reshape(tile, new Tensor(new int[] { -1, nb_classes, nb_features })); | |||||
| var cond_probs = tf.reduce_sum(dist.log_prob(r)); | |||||
| // uniform priors | |||||
| var priors = np.log(np.array<double>((1.0 / nb_classes) * nb_classes)); | |||||
| // posterior log probability, log P(c) + log P(x|c) | |||||
| var joint_likelihood = tf.add(new Tensor(priors), cond_probs); | |||||
| // normalize to get (log)-probabilities | |||||
| throw new NotFiniteNumberException(); | |||||
| var norm_factor = tf.reduce_logsumexp(joint_likelihood, new int[] { 1 }, true); | |||||
| var log_prob = joint_likelihood - norm_factor; | |||||
| // exp to get the actual probabilities | |||||
| return tf.exp(log_prob); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||