| @@ -51,6 +51,9 @@ namespace Tensorflow | |||||
| public Tensor inv(Tensor input, bool adjoint = false, string name = null) | public Tensor inv(Tensor input, bool adjoint = false, string name = null) | ||||
| => ops.matrix_inverse(input, adjoint: adjoint, name: name); | => ops.matrix_inverse(input, adjoint: adjoint, name: name); | ||||
| public Tensor global_norm(Tensor[] t_list, string name = null) | |||||
| => clip_ops.global_norm(t_list, name: name); | |||||
| public Tensor lstsq(Tensor matrix, Tensor rhs, | public Tensor lstsq(Tensor matrix, Tensor rhs, | ||||
| NDArray l2_regularizer = null, bool fast = true, string name = null) | NDArray l2_regularizer = null, bool fast = true, string name = null) | ||||
| => ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name); | => ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name); | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Linq; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -36,5 +37,24 @@ namespace Tensorflow | |||||
| return t_max; | return t_max; | ||||
| }); | }); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Computes the global norm of multiple tensors. | |||||
| /// </summary> | |||||
| /// <param name="t_list"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor global_norm(Tensor[] t_list, string name = null) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, "global_norm", t_list), delegate | |||||
| { | |||||
| var half_squared_norms = t_list.Select(v => nn_ops.l2_loss(v)).ToArray(); | |||||
| var half_squared_norm = math_ops.reduce_sum(array_ops.stack(half_squared_norms)); | |||||
| var norm = math_ops.sqrt(half_squared_norm * | |||||
| constant_op.constant(2.0, dtype: half_squared_norm.dtype), | |||||
| name: "global_norm"); | |||||
| return norm; | |||||
| }); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -75,11 +75,19 @@ namespace Tensorflow.Keras.Engine | |||||
| metric_obj = keras.metrics.sparse_categorical_accuracy; | metric_obj = keras.metrics.sparse_categorical_accuracy; | ||||
| else | else | ||||
| metric_obj = keras.metrics.categorical_accuracy; | metric_obj = keras.metrics.categorical_accuracy; | ||||
| metric = "accuracy"; | |||||
| } | } | ||||
| else if(metric == "mean_absolute_error" || metric == "mae") | else if(metric == "mean_absolute_error" || metric == "mae") | ||||
| { | |||||
| metric_obj = keras.metrics.mean_absolute_error; | metric_obj = keras.metrics.mean_absolute_error; | ||||
| metric = "mean_absolute_error"; | |||||
| } | |||||
| else if (metric == "mean_absolute_percentage_error" || metric == "mape") | else if (metric == "mean_absolute_percentage_error" || metric == "mape") | ||||
| { | |||||
| metric_obj = keras.metrics.mean_absolute_percentage_error; | metric_obj = keras.metrics.mean_absolute_percentage_error; | ||||
| metric = "mean_absolute_percentage_error"; | |||||
| } | |||||
| else | else | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| @@ -4,7 +4,6 @@ namespace Tensorflow.Keras.Metrics | |||||
| { | { | ||||
| public class MeanMetricWrapper : Mean | public class MeanMetricWrapper : Mean | ||||
| { | { | ||||
| string name; | |||||
| Func<Tensor, Tensor, Tensor> _fn = null; | Func<Tensor, Tensor, Tensor> _fn = null; | ||||
| public MeanMetricWrapper(Func<Tensor, Tensor, Tensor> fn, string name, TF_DataType dtype = TF_DataType.TF_FLOAT) | public MeanMetricWrapper(Func<Tensor, Tensor, Tensor> fn, string name, TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.ManagedAPI | namespace TensorFlowNET.UnitTest.ManagedAPI | ||||
| @@ -54,5 +55,13 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
| var e = tf.linalg.einsum("ij,jk->ik", (m0, m1)); | var e = tf.linalg.einsum("ij,jk->ik", (m0, m1)); | ||||
| Assert.AreEqual(e.shape, (2, 5)); | Assert.AreEqual(e.shape, (2, 5)); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void GlobalNorm() | |||||
| { | |||||
| var t_list = new Tensors(tf.constant(new float[] { 1, 2, 3, 4 }), tf.constant(new float[] { 5, 6, 7, 8 })); | |||||
| var norm = tf.linalg.global_norm(t_list); | |||||
| Assert.AreEqual(norm.numpy(), 14.282857f); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||