From 5511362d64930a9f6d1334342161fdb157629b2c Mon Sep 17 00:00:00 2001 From: dogvane Date: Wed, 4 Mar 2020 21:58:36 +0800 Subject: [PATCH] 1. fix miss type value. 2. fix type conver performance bug --- src/TensorFlowNET.Core/Tensors/TF_DataType.cs | 21 ++++++++++++++++++- src/TensorFlowNET.Core/Tensors/dtypes.cs | 18 ++++++---------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs index c916b321..5fe28c5d 100644 --- a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs +++ b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs @@ -35,6 +35,25 @@ DtFloatRef = 101, // DT_FLOAT_REF DtDoubleRef = 102, // DT_DOUBLE_REF DtInt32Ref = 103, // DT_INT32_REF - DtInt64Ref = 109 // DT_INT64_REF + DtUint8Ref = 104, + DtInt16Ref = 105, + DtInt8Ref = 106, + DtStringRef = 107, + DtComplex64Ref = 108, + DtInt64Ref = 109, // DT_INT64_REF + DtBoolRef = 110, + DtQint8Ref = 111, + DtQuint8Ref = 112, + DtQint32Ref = 113, + DtBfloat16Ref = 114, + DtQint16Ref = 115, + DtQuint16Ref = 116, + DtUint16Ref = 117, + DtComplex128Ref = 118, + DtHalfRef = 119, + DtResourceRef = 120, + DtVariantRef = 121, + DtUint32Ref = 122, + DtUint64Ref = 123, } } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index a54d0448..6c8b0385 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -46,7 +46,7 @@ namespace Tensorflow /// equivalent to , if none exists, returns null. public static Type as_numpy_dtype(this TF_DataType type) { - switch (type) + switch (type.as_base_dtype()) { case TF_DataType.TF_BOOL: return typeof(bool); @@ -182,14 +182,12 @@ namespace Tensorflow public static DataType as_datatype_enum(this TF_DataType type) { - return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid; + return (DataType)type; } public static TF_DataType as_base_dtype(this TF_DataType type) { - return (int)type > 100 ? - (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type - 100).ToString()) : - type; + return (int)type > 100 ? (TF_DataType)((int)type - 100) : type; } public static int name(this TF_DataType type) @@ -204,21 +202,17 @@ namespace Tensorflow public static DataType as_base_dtype(this DataType type) { - return (int)type > 100 ? - (DataType)Enum.Parse(typeof(DataType), ((int)type - 100).ToString()) : - type; + return (int)type > 100 ? (DataType)((int)type - 100) : type; } public static TF_DataType as_tf_dtype(this DataType type) { - return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid; + return (TF_DataType)type; } public static TF_DataType as_ref(this TF_DataType type) { - return (int)type < 100 ? - (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type + 100).ToString()) : - type; + return (int)type < 100 ? (TF_DataType)((int)type + 100) : type; } public static long max(this TF_DataType type)