| @@ -2,13 +2,17 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.APIs | |||||
| namespace Tensorflow | |||||
| { | { | ||||
| public static partial class tf | public static partial class tf | ||||
| { | { | ||||
| public static class distributions | public static class distributions | ||||
| { | { | ||||
| public static Normal(Tensor loc, Tensor scale, bool validate_args = false, bool allow_nan_stats = true, string name = "Normal") => Normal(loc, scale, validate_args = false, allow_nan_stats = true, "Normal"); | |||||
| public static Normal Normal(Tensor loc, | |||||
| Tensor scale, | |||||
| bool validate_args = false, | |||||
| bool allow_nan_stats = true, | |||||
| string name = "Normal") => new Normal(loc, scale, validate_args = false, allow_nan_stats = true, "Normal"); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,7 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Layers; | |||||
| using Tensorflow.Keras.Layers; | |||||
| using Tensorflow.Operations.Activation; | using Tensorflow.Operations.Activation; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -10,6 +10,8 @@ namespace Tensorflow | |||||
| public static Tensor sub(Tensor a, Tensor b) => gen_math_ops.sub(a, b); | public static Tensor sub(Tensor a, Tensor b) => gen_math_ops.sub(a, b); | ||||
| public static Tensor sqrt(Tensor a, string name = null) => gen_math_ops.sqrt(a, name); | |||||
| public static Tensor subtract<T>(Tensor x, T[] y, string name = null) where T : struct | public static Tensor subtract<T>(Tensor x, T[] y, string name = null) where T : struct | ||||
| => gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name); | => gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name); | ||||
| @@ -7,7 +7,7 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| class _BaseDistribution : Python | |||||
| public class _BaseDistribution : Python | |||||
| { | { | ||||
| // Abstract base class needed for resolving subclass hierarchy. | // Abstract base class needed for resolving subclass hierarchy. | ||||
| } | } | ||||
| @@ -17,10 +17,10 @@ namespace Tensorflow | |||||
| /// Distribution is a base class for constructing and organizing properties | /// Distribution is a base class for constructing and organizing properties | ||||
| /// (e.g., mean, variance) of random variables (e.g, Bernoulli, Gaussian). | /// (e.g., mean, variance) of random variables (e.g, Bernoulli, Gaussian). | ||||
| /// </summary> | /// </summary> | ||||
| class Distribution : _BaseDistribution | |||||
| public class Distribution : _BaseDistribution | |||||
| { | { | ||||
| public TF_DataType _dtype {get;set;} | public TF_DataType _dtype {get;set;} | ||||
| public static ReparameterizationType _reparameterization_type {get;set;} | |||||
| //public ReparameterizationType _reparameterization_type {get;set;} | |||||
| public bool _validate_args {get;set;} | public bool _validate_args {get;set;} | ||||
| public bool _allow_nan_stats {get;set;} | public bool _allow_nan_stats {get;set;} | ||||
| public Dictionary<string, object> _parameters {get;set;} | public Dictionary<string, object> _parameters {get;set;} | ||||
| @@ -3,7 +3,7 @@ using Tensorflow; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| class Normal : Distribution | |||||
| public class Normal : Distribution | |||||
| { | { | ||||
| public Tensor _loc { get; set; } | public Tensor _loc { get; set; } | ||||
| public Tensor _scale { get; set; } | public Tensor _scale { get; set; } | ||||
| @@ -25,7 +25,7 @@ namespace Tensorflow | |||||
| /// <param name="validate_args"></param> | /// <param name="validate_args"></param> | ||||
| /// <param name="allow_nan_stats"></param> | /// <param name="allow_nan_stats"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| Normal (Tensor loc, Tensor scale, bool validate_args=false, bool allow_nan_stats=true, string name="Normal") | |||||
| public Normal (Tensor loc, Tensor scale, bool validate_args=false, bool allow_nan_stats=true, string name="Normal") | |||||
| { | { | ||||
| parameters.Add("name", name); | parameters.Add("name", name); | ||||
| parameters.Add("loc", loc); | parameters.Add("loc", loc); | ||||
| @@ -62,6 +62,13 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor sqrt(Tensor x, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Sqrt", name, args: new { x }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); | ||||
| @@ -74,7 +74,7 @@ namespace TensorFlowNET.Examples | |||||
| // Create a 3x2 univariate normal distribution with the | // Create a 3x2 univariate normal distribution with the | ||||
| // Known mean and variance | // Known mean and variance | ||||
| var dist = tf.distributions.Normal(loc=mean, scale=tf.sqrt(variance)); | |||||
| var dist = tf.distributions.Normal(mean, tf.sqrt(variance)); | |||||
| } | } | ||||