From 38f27f16418b55b6cba4937e1cb6dd31a5607c66 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Wed, 21 Aug 2019 23:54:50 +0300 Subject: [PATCH] Tensor: renamed _dtype to _override_dtype - Fixed all locations _dtype is used incorrectly. --- .../Tensors/Tensor.Creation.cs | 2 +- .../Tensors/Tensor.Explicit.cs | 4 ++-- .../Tensors/Tensor.Operators.cs | 9 ++++----- src/TensorFlowNET.Core/Tensors/Tensor.cs | 16 ++++++++-------- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index ea58607b..73f116ec 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -594,7 +594,7 @@ namespace Tensorflow { _op = op; _value_index = value_index; - _dtype = dtype; + _override_dtype = dtype; _id = ops.uid(); } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs index 77603e49..6d7f20f1 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs @@ -128,8 +128,8 @@ namespace Tensorflow [MethodImpl(MethodImplOptions.AggressiveInlining)] private static void EnsureDType(Tensor tensor, TF_DataType @is) { - if (tensor._dtype != @is) - throw new InvalidCastException($"Unable to cast scalar tensor {tensor._dtype} to {@is}"); + if (tensor.dtype != @is) + throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}"); } [MethodImpl(MethodImplOptions.AggressiveInlining)] diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index eb912eb9..02d19e56 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -69,11 +69,12 @@ namespace Tensorflow TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 }; + public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); public static Tensor operator /(Tensor x, Tensor y) => - _intTfDataTypes.Contains(x._dtype) + _intTfDataTypes.Contains(x.dtype) ? BinaryOpWrapper("floordiv", x, y) : BinaryOpWrapper("truediv", x, y); public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); @@ -122,8 +123,7 @@ namespace Tensorflow if (y is Tensor tr) dtype = tr.dtype.as_base_dtype(); - var namescope = ops.name_scope(null, name, new { x, y }); - return tf_with(namescope, scope => + using (var scope = ops.name_scope(null, name, new { x, y })) { Tensor result = null; var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); @@ -154,8 +154,7 @@ namespace Tensorflow } return result; - }); - + } } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 6f4cc21a..b23b8b98 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -43,7 +43,7 @@ namespace Tensorflow private readonly Operation _op; private readonly int _value_index; private TF_Output? _tf_output; - private readonly TF_DataType _dtype; + private readonly TF_DataType _override_dtype; public int Id => _id; @@ -72,7 +72,7 @@ namespace Tensorflow /// /// The DType of elements in this tensor. /// - public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); + public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); @@ -231,7 +231,7 @@ namespace Tensorflow } //Are the types matching? - if (typeof(T).as_dtype() == _dtype) + if (typeof(T).as_dtype() == dtype) { if (NDims == 0 && size == 1) //is it a scalar? { @@ -274,7 +274,7 @@ namespace Tensorflow { #if _REGEN #region Compute - switch (_dtype.as_numpy_dtype().GetTypeCode()) + switch (dtype.as_numpy_dtype().GetTypeCode()) { %foreach supported_dtypes,supported_dtypes_lowercase% case NPTypeCode.#1: return new T[] {Converts.ChangeType(*(#2*) buffer, NPTypeCode.#1)}; @@ -286,7 +286,7 @@ namespace Tensorflow #endregion #else #region Compute - switch (_dtype.as_numpy_dtype().GetTypeCode()) + switch (dtype.as_numpy_dtype()?.GetTypeCode()) { case NPTypeCode.Boolean: return new T[] {Converts.ChangeType(*(bool*) buffer, NPTypeCode.Boolean)}; case NPTypeCode.Byte: return new T[] {Converts.ChangeType(*(byte*) buffer, NPTypeCode.Byte)}; @@ -301,7 +301,7 @@ namespace Tensorflow case NPTypeCode.Single: return new T[] {Converts.ChangeType(*(float*) buffer, NPTypeCode.Single)}; case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this, NPTypeCode.String)}; default: - throw new NotSupportedException(); + throw new NotSupportedException(); } #endregion #endif @@ -318,7 +318,7 @@ namespace Tensorflow #if _REGEN #region Compute - switch (_dtype.as_numpy_dtype().GetTypeCode()) + switch (dtype.as_numpy_dtype().GetTypeCode()) { %foreach supported_dtypes,supported_dtypes_lowercase% case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; @@ -329,7 +329,7 @@ namespace Tensorflow #endregion #else #region Compute - switch (_dtype.as_numpy_dtype().GetTypeCode()) + switch (dtype.as_numpy_dtype().GetTypeCode()) { case NPTypeCode.Boolean: new UnmanagedMemoryBlock((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; case NPTypeCode.Byte: new UnmanagedMemoryBlock((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;