From 0a08386ca95aaa5bc50cf581f97e5c611cdd5fcf Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Sat, 26 Nov 2022 16:01:14 -0600 Subject: [PATCH] Fix batch_size for Keras Input. --- src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs | 1 + src/TensorFlowNET.Keras/Layers/LayersApi.cs | 2 ++ src/python/xor_keras.py | 1 + 3 files changed, 4 insertions(+) diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs index 5945bb55..3f4d1ed8 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs @@ -108,6 +108,7 @@ namespace Tensorflow.Keras.Layers public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); public Tensors Input(Shape shape, + int batch_size = -1, string name = null, bool sparse = false, bool ragged = false); diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 8498f5ac..50c66be7 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -469,6 +469,7 @@ namespace Tensorflow.Keras.Layers /// /// A tensor. public Tensors Input(Shape shape, + int batch_size = -1, string name = null, bool sparse = false, bool ragged = false) @@ -476,6 +477,7 @@ namespace Tensorflow.Keras.Layers var input_layer = new InputLayer(new InputLayerArgs { InputShape = shape, + BatchSize= batch_size, Name = name, Sparse = sparse, Ragged = ragged diff --git a/src/python/xor_keras.py b/src/python/xor_keras.py index ffd88b61..e7388605 100644 --- a/src/python/xor_keras.py +++ b/src/python/xor_keras.py @@ -4,6 +4,7 @@ import tensorflow as tf os.environ["CUDA_VISIBLE_DEVICES"] = "-1" print(tf.__version__) +# https://playground.tensorflow.org/ # tf.compat.v1.enable_eager_execution() # tf.debugging.set_log_device_placement(True); tf.config.run_functions_eagerly(True)