Browse Source

Merge pull request #1116 from lingbai-kong/imdbfix

fix: type converting errors when loading imdb dataset
tags/v0.110.0-LSTM-Model
Haiping GitHub 2 years ago
parent
commit
3de7b8e8ed
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 0 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Operations/array_ops.cs

+ 6
- 0
src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs View File

@@ -107,9 +107,15 @@ namespace Tensorflow.NumPy
public static implicit operator NDArray(bool value)
=> new NDArray(value);

public static implicit operator NDArray(byte value)
=> new NDArray(value);

public static implicit operator NDArray(int value)
=> new NDArray(value);

public static implicit operator NDArray(long value)
=> new NDArray(value);

public static implicit operator NDArray(float value)
=> new NDArray(value);



+ 11
- 0
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -84,8 +84,13 @@ namespace Tensorflow
// var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
Tensor zeros = dtype switch
{
TF_DataType.TF_BOOL => constant(false),
TF_DataType.TF_DOUBLE => constant(0d),
TF_DataType.TF_FLOAT => constant(0f),
TF_DataType.TF_INT64 => constant(0L),
TF_DataType.TF_UINT64 => constant((ulong)0),
TF_DataType.TF_INT32 => constant(0),
TF_DataType.TF_UINT32 => constant((uint)0),
TF_DataType.TF_INT8 => constant((sbyte)0),
TF_DataType.TF_UINT8 => constant((byte)0),
_ => constant(0)
@@ -108,9 +113,15 @@ namespace Tensorflow
return _constant_if_small(0.0F, shape, dtype, name);
case TF_DataType.TF_INT64:
return _constant_if_small(0L, shape, dtype, name);
case TF_DataType.TF_UINT64:
return _constant_if_small<ulong>(0, shape, dtype, name);
case TF_DataType.TF_INT32:
return _constant_if_small(0, shape, dtype, name);
case TF_DataType.TF_UINT32:
return _constant_if_small<uint>(0, shape, dtype, name);
case TF_DataType.TF_INT8:
return _constant_if_small<sbyte>(0, shape, dtype, name);
case TF_DataType.TF_UINT8:
return _constant_if_small<byte>(0, shape, dtype, name);
default:
throw new TypeError("can't find type for zeros");


Loading…
Cancel
Save