From fcd10447abb20e50ed2d67e313c2f75566319649 Mon Sep 17 00:00:00 2001 From: lingbai-kong Date: Fri, 23 Jun 2023 13:39:36 +0800 Subject: [PATCH] add more type case for tensor.zeros --- src/TensorFlowNET.Core/Operations/array_ops.cs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index a0b47aac..24c39215 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -84,8 +84,13 @@ namespace Tensorflow // var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); Tensor zeros = dtype switch { + TF_DataType.TF_BOOL => constant(false), TF_DataType.TF_DOUBLE => constant(0d), TF_DataType.TF_FLOAT => constant(0f), + TF_DataType.TF_INT64 => constant(0L), + TF_DataType.TF_UINT64 => constant((ulong)0), + TF_DataType.TF_INT32 => constant(0), + TF_DataType.TF_UINT32 => constant((uint)0), TF_DataType.TF_INT8 => constant((sbyte)0), TF_DataType.TF_UINT8 => constant((byte)0), _ => constant(0) @@ -108,9 +113,15 @@ namespace Tensorflow return _constant_if_small(0.0F, shape, dtype, name); case TF_DataType.TF_INT64: return _constant_if_small(0L, shape, dtype, name); + case TF_DataType.TF_UINT64: + return _constant_if_small(0, shape, dtype, name); case TF_DataType.TF_INT32: return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_UINT32: + return _constant_if_small(0, shape, dtype, name); case TF_DataType.TF_INT8: + return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_UINT8: return _constant_if_small(0, shape, dtype, name); default: throw new TypeError("can't find type for zeros");