Browse Source

Tensor: renamed _dtype to _override_dtype

- Fixed all locations _dtype is used incorrectly.
tags/v0.12
Eli Belash 6 years ago
parent
commit
38f27f1641
4 changed files with 15 additions and 16 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs
  3. +4
    -5
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  4. +8
    -8
      src/TensorFlowNET.Core/Tensors/Tensor.cs

+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -594,7 +594,7 @@ namespace Tensorflow
{ {
_op = op; _op = op;
_value_index = value_index; _value_index = value_index;
_dtype = dtype;
_override_dtype = dtype;
_id = ops.uid(); _id = ops.uid();
} }




+ 2
- 2
src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs View File

@@ -128,8 +128,8 @@ namespace Tensorflow
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void EnsureDType(Tensor tensor, TF_DataType @is) private static void EnsureDType(Tensor tensor, TF_DataType @is)
{ {
if (tensor._dtype != @is)
throw new InvalidCastException($"Unable to cast scalar tensor {tensor._dtype} to {@is}");
if (tensor.dtype != @is)
throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}");
} }


[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]


+ 4
- 5
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -69,11 +69,12 @@ namespace Tensorflow
TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32,
TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64
}; };

public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y);
public static Tensor operator /(Tensor x, Tensor y) => public static Tensor operator /(Tensor x, Tensor y) =>
_intTfDataTypes.Contains(x._dtype)
_intTfDataTypes.Contains(x.dtype)
? BinaryOpWrapper("floordiv", x, y) ? BinaryOpWrapper("floordiv", x, y)
: BinaryOpWrapper("truediv", x, y); : BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y);
@@ -122,8 +123,7 @@ namespace Tensorflow
if (y is Tensor tr) if (y is Tensor tr)
dtype = tr.dtype.as_base_dtype(); dtype = tr.dtype.as_base_dtype();


var namescope = ops.name_scope(null, name, new { x, y });
return tf_with(namescope, scope =>
using (var scope = ops.name_scope(null, name, new { x, y }))
{ {
Tensor result = null; Tensor result = null;
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
@@ -154,8 +154,7 @@ namespace Tensorflow
} }


return result; return result;
});

}
} }
} }
} }

+ 8
- 8
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow
private readonly Operation _op; private readonly Operation _op;
private readonly int _value_index; private readonly int _value_index;
private TF_Output? _tf_output; private TF_Output? _tf_output;
private readonly TF_DataType _dtype;
private readonly TF_DataType _override_dtype;


public int Id => _id; public int Id => _id;


@@ -72,7 +72,7 @@ namespace Tensorflow
/// <summary> /// <summary>
/// The DType of elements in this tensor. /// The DType of elements in this tensor.
/// </summary> /// </summary>
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle);


public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
@@ -231,7 +231,7 @@ namespace Tensorflow
} }


//Are the types matching? //Are the types matching?
if (typeof(T).as_dtype() == _dtype)
if (typeof(T).as_dtype() == dtype)
{ {
if (NDims == 0 && size == 1) //is it a scalar? if (NDims == 0 && size == 1) //is it a scalar?
{ {
@@ -274,7 +274,7 @@ namespace Tensorflow
{ {
#if _REGEN #if _REGEN
#region Compute #region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
switch (dtype.as_numpy_dtype().GetTypeCode())
{ {
%foreach supported_dtypes,supported_dtypes_lowercase% %foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer, NPTypeCode.#1)}; case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer, NPTypeCode.#1)};
@@ -286,7 +286,7 @@ namespace Tensorflow
#endregion #endregion
#else #else
#region Compute #region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
switch (dtype.as_numpy_dtype()?.GetTypeCode())
{ {
case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer, NPTypeCode.Boolean)}; case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer, NPTypeCode.Boolean)};
case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer, NPTypeCode.Byte)}; case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer, NPTypeCode.Byte)};
@@ -301,7 +301,7 @@ namespace Tensorflow
case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer, NPTypeCode.Single)}; case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer, NPTypeCode.Single)};
case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)};
default: default:
throw new NotSupportedException();
throw new NotSupportedException();
} }
#endregion #endregion
#endif #endif
@@ -318,7 +318,7 @@ namespace Tensorflow


#if _REGEN #if _REGEN
#region Compute #region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
switch (dtype.as_numpy_dtype().GetTypeCode())
{ {
%foreach supported_dtypes,supported_dtypes_lowercase% %foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
@@ -329,7 +329,7 @@ namespace Tensorflow
#endregion #endregion
#else #else
#region Compute #region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
switch (dtype.as_numpy_dtype().GetTypeCode())
{ {
case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;


Loading…
Cancel
Save