From 74b3cf74900c5bbd594c0f0a6391e9a47b20980f Mon Sep 17 00:00:00 2001 From: carb0n <58676303+carb0n@users.noreply.github.com> Date: Tue, 4 Aug 2020 22:12:06 -0400 Subject: [PATCH] implement math_ops.real, reduce_variance, reduce_std * implement math_ops.real, reduce_variance, reduce_std * add dtype.real_dtype() * add outward-facing api functions --- src/TensorFlowNET.Core/APIs/tf.math.cs | 9 +++ src/TensorFlowNET.Core/Operations/math_ops.cs | 56 +++++++++++++++++++ src/TensorFlowNET.Core/Tensors/dtypes.cs | 11 ++++ 3 files changed, 76 insertions(+) diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 49bd7ca8..22e875cb 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -422,6 +422,9 @@ namespace Tensorflow public Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") => math_ops.range(start, limit: limit, delta: delta, dtype: dtype, name: name); + public Tensor real(Tensor input, string name = null) + => math_ops.real(input, name); + /// /// Computes the "logical or" of elements across dimensions of a tensor. /// @@ -509,6 +512,12 @@ namespace Tensorflow public Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) => math_ops.reduce_min(input_tensor, axis, keepdims, name); + public Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_std(input_tensor, axis, keepdims, name); + + public Tensor reduce_variance(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_variance(input_tensor, axis, keepdims, name); + public Tensor sigmoid(T x, string name = null) => math_ops.sigmoid(x, name: name); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 36617a31..756b9a89 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -229,6 +229,22 @@ namespace Tensorflow public static Tensor mul_no_nan(Tx x, Ty y, string name = null) => gen_math_ops.mul_no_nan(x, y, name: name); + + public static Tensor real(Tensor input, string name = null) + { + using (var name_ = ops.name_scope(name, "Real", new [] {input})) + { + input = ops.convert_to_tensor(input, name: "input"); + if (input.dtype.is_complex()) + { + var real_dtype = input.dtype.real_dtype(); + return real(input, name: name); + } else + { + return input; + } + } + } /// /// Computes the mean of elements across dimensions of a tensor. @@ -295,6 +311,46 @@ namespace Tensorflow } } + public static Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + { + if (name == null) + name = "reduce_std"; + // else {name = name;} + + using (ops.name_scope(name)) + { + var variance = reduce_variance(input_tensor, axis: axis, keepdims: keepdims); + return gen_math_ops.sqrt(variance); + } + } + + public static Tensor reduce_variance(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + { + if (name == null) + name = "reduce_variance"; + // else {name = name;} + + using (ops.name_scope(name)) + { + var means = reduce_mean(input_tensor, axis: axis, keepdims: true); + if (means.dtype.is_integer()) + throw new TypeError("Input must be either real or complex"); + var diff = input_tensor - means; + + Tensor squared_deviations; + if (diff.dtype.is_complex()) + { + var real_dtype = diff.dtype.real_dtype(); + squared_deviations = real( + gen_math_ops.mul(conj(diff), diff)); + } else + { + squared_deviations = gen_math_ops.square(diff); + } + return reduce_mean(squared_deviations, axis: axis, keepdims: keepdims); + } + } + public static Tensor sigmoid(T x, string name = null) => tf_with(ops.name_scope(name, "Sigmoid", x), scope => { diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 2b03fa64..e32d0952 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -278,5 +278,16 @@ namespace Tensorflow { return self.as_datatype_enum() == other.as_datatype_enum(); } + + public static TF_DataType real_dtype(this TF_DataType self) + { + TF_DataType base_ = self.as_base_dtype(); + if (base_ == complex64) + return float32; + else if (base_ == complex128) + return float64; + else + return self; + } } }