From 9264a190022e78f5120aeee09a7daafa701ec267 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 11 Jul 2021 20:58:42 -0500 Subject: [PATCH] fix string data. --- src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs | 2 +- src/TensorFlowNET.Core/Tensors/Tensor.Value.cs | 8 ++++++-- src/TensorFlowNET.Core/ops.cs | 3 +++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs index 1f3064a8..825c0ac2 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs @@ -37,7 +37,7 @@ namespace Tensorflow.NumPy => new NDArray(value); public static implicit operator Tensor(NDArray nd) - => constant_op.constant(nd); + => nd._tensor; public static implicit operator NDArray(Tensor tensor) => new NDArray(tensor); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index cb88462b..ed72d9aa 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -49,8 +49,12 @@ namespace Tensorflow protected NDArray GetNDArray(TF_DataType dtype) { - /*if (dtype == TF_DataType.TF_STRING) - return np.array(StringData());*/ + if (dtype == TF_DataType.TF_STRING) + { + var str= StringData(); + return new NDArray(str, new Shape(str.Length)); + } + return new NDArray(this); } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 07697d5f..e86c45b9 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -171,6 +171,9 @@ namespace Tensorflow _ => constant_op.constant(value, dtype: dtype, name: name) }; + if (dtype == TF_DataType.TF_STRING) + return ret; + var original_dtype = value.GetDataType(); if (dtype != TF_DataType.DtInvalid && dtype != original_dtype) ret = gen_math_ops.cast(ret, dtype.as_base_dtype(), name: name);