From 9aa8f758f402cb108fabd7be1c043996acadf859 Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Thu, 2 Mar 2023 21:57:48 -0600 Subject: [PATCH] Fix sparse_categorical_crossentropy. --- src/TensorFlowNET.Keras/BackendImpl.cs | 5 ++--- test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index c49fc140..01aa59b9 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -307,9 +307,8 @@ namespace Tensorflow.Keras var update_shape = target_rank > -1 && output_rank > -1 && target_rank != output_rank - 1; if (update_shape) { - /*var target = flatten(target); - output = tf.reshape(output, [-1, output_shape[-1]]);*/ - throw new NotImplementedException(""); + target = tf.reshape(target, -1); + output = tf.reshape(output, (-1, output.shape[-1])); } if (ignore_class.HasValue) diff --git a/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs b/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs index ffe3d6f4..555154d7 100644 --- a/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/MultiThreadsTest.cs @@ -8,7 +8,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; namespace TensorFlowNET.Keras.UnitTest { - [TestClass, Ignore] + [TestClass] public class MultiThreads { [TestMethod]