| @@ -589,23 +589,17 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
| { | { | ||||
| return "<unprintable>"; | return "<unprintable>"; | ||||
| } | } | ||||
| else if (dtype == TF_DataType.TF_RESOURCE) | |||||
| { | |||||
| return "<unprintable>"; | |||||
| } | |||||
| var nd = tensor.numpy(); | var nd = tensor.numpy(); | ||||
| if (nd.size == 0) | if (nd.size == 0) | ||||
| return "[]"; | return "[]"; | ||||
| switch (dtype) | |||||
| { | |||||
| case TF_DataType.TF_STRING: | |||||
| return string.Join(string.Empty, nd.ToArray<byte>() | |||||
| .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); | |||||
| case TF_DataType.TF_VARIANT: | |||||
| case TF_DataType.TF_RESOURCE: | |||||
| return "<unprintable>"; | |||||
| default: | |||||
| return nd.ToString(); | |||||
| } | |||||
| return nd.ToString(); | |||||
| } | } | ||||
| public static ParsedSliceArgs ParseSlices(Slice[] slices) | public static ParsedSliceArgs ParseSlices(Slice[] slices) | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.Losses; | |||||
| using Tensorflow.Keras.Metrics; | using Tensorflow.Keras.Metrics; | ||||
| using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
| @@ -74,11 +75,15 @@ namespace Tensorflow.Keras.Engine | |||||
| metric_obj = keras.metrics.sparse_categorical_accuracy; | metric_obj = keras.metrics.sparse_categorical_accuracy; | ||||
| else | else | ||||
| metric_obj = keras.metrics.categorical_accuracy; | metric_obj = keras.metrics.categorical_accuracy; | ||||
| return new MeanMetricWrapper(metric_obj, metric); | |||||
| } | } | ||||
| else if(metric == "mean_absolute_error" || metric == "mae") | |||||
| { | |||||
| metric_obj = keras.metrics.mean_absolute_error; | |||||
| } | |||||
| else | |||||
| throw new NotImplementedException(""); | |||||
| throw new NotImplementedException(""); | |||||
| return new MeanMetricWrapper(metric_obj, metric); | |||||
| } | } | ||||
| public IEnumerable<Metric> metrics | public IEnumerable<Metric> metrics | ||||
| @@ -40,5 +40,11 @@ namespace Tensorflow.Keras.Metrics | |||||
| return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT); | return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT); | ||||
| } | } | ||||
| public Tensor mean_absolute_error(Tensor y_true, Tensor y_pred) | |||||
| { | |||||
| y_true = math_ops.cast(y_true, y_pred.dtype); | |||||
| return keras.backend.mean(math_ops.abs(y_pred - y_true), axis: -1); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||