Browse Source

Merge pull request #1116 from lingbai-kong/imdbfix

fix: type converting errors when loading imdb dataset
tags/v0.110.0-LSTM-Model
Haiping GitHub 2 years ago
parent
commit
3de7b8e8ed
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 0 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Operations/array_ops.cs

+ 6
- 0
src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs View File

@@ -107,9 +107,15 @@ namespace Tensorflow.NumPy
public static implicit operator NDArray(bool value) public static implicit operator NDArray(bool value)
=> new NDArray(value); => new NDArray(value);


public static implicit operator NDArray(byte value)
=> new NDArray(value);

public static implicit operator NDArray(int value) public static implicit operator NDArray(int value)
=> new NDArray(value); => new NDArray(value);


public static implicit operator NDArray(long value)
=> new NDArray(value);

public static implicit operator NDArray(float value) public static implicit operator NDArray(float value)
=> new NDArray(value); => new NDArray(value);




+ 11
- 0
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -84,8 +84,13 @@ namespace Tensorflow
// var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); // var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
Tensor zeros = dtype switch Tensor zeros = dtype switch
{ {
TF_DataType.TF_BOOL => constant(false),
TF_DataType.TF_DOUBLE => constant(0d), TF_DataType.TF_DOUBLE => constant(0d),
TF_DataType.TF_FLOAT => constant(0f), TF_DataType.TF_FLOAT => constant(0f),
TF_DataType.TF_INT64 => constant(0L),
TF_DataType.TF_UINT64 => constant((ulong)0),
TF_DataType.TF_INT32 => constant(0),
TF_DataType.TF_UINT32 => constant((uint)0),
TF_DataType.TF_INT8 => constant((sbyte)0), TF_DataType.TF_INT8 => constant((sbyte)0),
TF_DataType.TF_UINT8 => constant((byte)0), TF_DataType.TF_UINT8 => constant((byte)0),
_ => constant(0) _ => constant(0)
@@ -108,9 +113,15 @@ namespace Tensorflow
return _constant_if_small(0.0F, shape, dtype, name); return _constant_if_small(0.0F, shape, dtype, name);
case TF_DataType.TF_INT64: case TF_DataType.TF_INT64:
return _constant_if_small(0L, shape, dtype, name); return _constant_if_small(0L, shape, dtype, name);
case TF_DataType.TF_UINT64:
return _constant_if_small<ulong>(0, shape, dtype, name);
case TF_DataType.TF_INT32: case TF_DataType.TF_INT32:
return _constant_if_small(0, shape, dtype, name); return _constant_if_small(0, shape, dtype, name);
case TF_DataType.TF_UINT32:
return _constant_if_small<uint>(0, shape, dtype, name);
case TF_DataType.TF_INT8: case TF_DataType.TF_INT8:
return _constant_if_small<sbyte>(0, shape, dtype, name);
case TF_DataType.TF_UINT8:
return _constant_if_small<byte>(0, shape, dtype, name); return _constant_if_small<byte>(0, shape, dtype, name);
default: default:
throw new TypeError("can't find type for zeros"); throw new TypeError("can't find type for zeros");


Loading…
Cancel
Save