Browse Source

Tensor: renamed _dtype to _override_dtype

- Fixed all locations _dtype is used incorrectly.
tags/v0.12
Eli Belash 6 years ago
parent
commit
38f27f1641
4 changed files with 15 additions and 16 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs
  3. +4
    -5
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  4. +8
    -8
      src/TensorFlowNET.Core/Tensors/Tensor.cs

+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -594,7 +594,7 @@ namespace Tensorflow
{
_op = op;
_value_index = value_index;
_dtype = dtype;
_override_dtype = dtype;
_id = ops.uid();
}



+ 2
- 2
src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs View File

@@ -128,8 +128,8 @@ namespace Tensorflow
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void EnsureDType(Tensor tensor, TF_DataType @is)
{
if (tensor._dtype != @is)
throw new InvalidCastException($"Unable to cast scalar tensor {tensor._dtype} to {@is}");
if (tensor.dtype != @is)
throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}");
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]


+ 4
- 5
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -69,11 +69,12 @@ namespace Tensorflow
TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32,
TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64
};

public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y);
public static Tensor operator /(Tensor x, Tensor y) =>
_intTfDataTypes.Contains(x._dtype)
_intTfDataTypes.Contains(x.dtype)
? BinaryOpWrapper("floordiv", x, y)
: BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y);
@@ -122,8 +123,7 @@ namespace Tensorflow
if (y is Tensor tr)
dtype = tr.dtype.as_base_dtype();

var namescope = ops.name_scope(null, name, new { x, y });
return tf_with(namescope, scope =>
using (var scope = ops.name_scope(null, name, new { x, y }))
{
Tensor result = null;
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
@@ -154,8 +154,7 @@ namespace Tensorflow
}

return result;
});

}
}
}
}

+ 8
- 8
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow
private readonly Operation _op;
private readonly int _value_index;
private TF_Output? _tf_output;
private readonly TF_DataType _dtype;
private readonly TF_DataType _override_dtype;

public int Id => _id;

@@ -72,7 +72,7 @@ namespace Tensorflow
/// <summary>
/// The DType of elements in this tensor.
/// </summary>
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle);

public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
@@ -231,7 +231,7 @@ namespace Tensorflow
}

//Are the types matching?
if (typeof(T).as_dtype() == _dtype)
if (typeof(T).as_dtype() == dtype)
{
if (NDims == 0 && size == 1) //is it a scalar?
{
@@ -274,7 +274,7 @@ namespace Tensorflow
{
#if _REGEN
#region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
switch (dtype.as_numpy_dtype().GetTypeCode())
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer, NPTypeCode.#1)};
@@ -286,7 +286,7 @@ namespace Tensorflow
#endregion
#else
#region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
switch (dtype.as_numpy_dtype()?.GetTypeCode())
{
case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer, NPTypeCode.Boolean)};
case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer, NPTypeCode.Byte)};
@@ -301,7 +301,7 @@ namespace Tensorflow
case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer, NPTypeCode.Single)};
case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)};
default:
throw new NotSupportedException();
throw new NotSupportedException();
}
#endregion
#endif
@@ -318,7 +318,7 @@ namespace Tensorflow

#if _REGEN
#region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
switch (dtype.as_numpy_dtype().GetTypeCode())
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
@@ -329,7 +329,7 @@ namespace Tensorflow
#endregion
#else
#region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
switch (dtype.as_numpy_dtype().GetTypeCode())
{
case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;


Loading…
Cancel
Save