From ec340eeff57c7f9bef8fc21dd94f17889b7453b5 Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Sat, 4 Feb 2023 12:06:21 -0600 Subject: [PATCH] np.ones_like and np.zeros_like --- src/TensorFlowNET.Console/SimpleRnnTest.cs | 1 - src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs | 9 ++++++--- src/TensorFlowNET.Core/Sessions/BaseSession.cs | 9 ++++++++- src/python/simple_rnn.py | 18 ++++++++++-------- .../Tensorflow.Keras.UnitTest.csproj | 1 - 5 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/TensorFlowNET.Console/SimpleRnnTest.cs b/src/TensorFlowNET.Console/SimpleRnnTest.cs index da124517..9769eb65 100644 --- a/src/TensorFlowNET.Console/SimpleRnnTest.cs +++ b/src/TensorFlowNET.Console/SimpleRnnTest.cs @@ -12,7 +12,6 @@ namespace Tensorflow { public void Run() { - tf.UseKeras(); var inputs = np.random.random((6, 10, 8)).astype(np.float32); //var simple_rnn = tf.keras.layers.SimpleRNN(4); //var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`. diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs index 7e6a2b65..9604392c 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs @@ -2,7 +2,6 @@ using System.Collections; using System.Collections.Generic; using System.IO; -using System.Numerics; using System.Text; using static Tensorflow.Binding; @@ -103,11 +102,15 @@ namespace Tensorflow.NumPy public static NDArray ones(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => new NDArray(tf.ones(shape, dtype: dtype)); - public static NDArray ones_like(NDArray a, Type dtype = null) - => throw new NotImplementedException(""); + public static NDArray ones_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid) + => new NDArray(tf.ones_like(a, dtype: dtype)); [AutoNumPy] public static NDArray zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => new NDArray(tf.zeros(shape, dtype: dtype)); + + [AutoNumPy] + public static NDArray zeros_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid) + => new NDArray(tf.zeros_like(a, dtype: dtype)); } } diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 0051a6b3..01ba0407 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -291,7 +291,14 @@ namespace Tensorflow protected override void DisposeUnmanagedResources(IntPtr handle) { // c_api.TF_CloseSession(handle, tf.Status.Handle); - c_api.TF_DeleteSession(handle, c_api.TF_NewStatus()); + if (tf.Status == null || tf.Status.Handle.IsInvalid) + { + c_api.TF_DeleteSession(handle, c_api.TF_NewStatus()); + } + else + { + c_api.TF_DeleteSession(handle, tf.Status.Handle); + } } } } diff --git a/src/python/simple_rnn.py b/src/python/simple_rnn.py index 97f9f3f3..c5f3b1f2 100644 --- a/src/python/simple_rnn.py +++ b/src/python/simple_rnn.py @@ -1,15 +1,17 @@ import numpy as np import tensorflow as tf +import tensorflow.experimental.numpy as tnp # tf.experimental.numpy -inputs = np.random.random([32, 10, 8]).astype(np.float32) -simple_rnn = tf.keras.layers.SimpleRNN(4) +inputs = np.arange(6 * 10 * 8).reshape([6, 10, 8]).astype(np.float32) +# simple_rnn = tf.keras.layers.SimpleRNN(4) -output = simple_rnn(inputs) # The output has shape `[32, 4]`. +# output = simple_rnn(inputs) # The output has shape `[6, 4]`. -simple_rnn = tf.keras.layers.SimpleRNN( - 4, return_sequences=True, return_state=True) +simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences=True, return_state=True) -# whole_sequence_output has shape `[32, 10, 4]`. -# final_state has shape `[32, 4]`. -whole_sequence_output, final_state = simple_rnn(inputs) \ No newline at end of file +# whole_sequence_output has shape `[6, 10, 4]`. +# final_state has shape `[6, 4]`. +whole_sequence_output, final_state = simple_rnn(inputs) +print(whole_sequence_output) +print(final_state) \ No newline at end of file diff --git a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj index fc693b1e..61e522e6 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj @@ -4,7 +4,6 @@ net6.0 false - 11.0 AnyCPU;x64