From 1c5731faf5cc4983eae395e83ccc285df2d1b2ef Mon Sep 17 00:00:00 2001 From: haiping008 Date: Tue, 5 Feb 2019 16:54:50 -0600 Subject: [PATCH] fix internal_convert_n_to_tensor return type. --- .../Operations/OpDefLibrary.cs | 25 +++++++------------ src/TensorFlowNET.Core/ops.py.cs | 10 ++++---- 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 1ee00b15..e11a0e5f 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -84,32 +84,25 @@ namespace Tensorflow dtype = dtype.as_base_dtype(); values = ops.internal_convert_n_to_tensor(values, name: input_arg.Name, dtype: dtype, preferred_dtype: default_dtype, as_ref: input_arg.IsRef); - - inputs.AddRange(values as Tensor[]); } else { - if (!(values is Tensor)) + if (keywords[input_name] is Tensor) { - keywords[input_name] = constant_op.constant(values, input_name); } - - if (keywords[input_name] is Tensor value) + else { - if (keywords.ContainsKey(input_name)) - { - inputs.Add(value); - } - - if (!String.IsNullOrEmpty(input_arg.TypeAttr)) - { - attrs[input_arg.TypeAttr] = value.dtype; - } + keywords[input_name] = ops.internal_convert_to_tensor(values, name: input_name); + } - values = new Tensor[] { value }; + if (!String.IsNullOrEmpty(input_arg.TypeAttr)) + { + attrs[input_arg.TypeAttr] = (keywords[input_name] as Tensor).dtype; } + values = new Tensor[] { keywords[input_name] as Tensor }; } + inputs.AddRange(values as Tensor[]); base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype())); input_types.AddRange(base_types); } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 5239efda..f72eb3bc 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -310,11 +310,11 @@ namespace Tensorflow }; } - public static T[] internal_convert_n_to_tensor(T[] values, DataType dtype = DataType.DtInvalid, + public static Tensor[] internal_convert_n_to_tensor(T[] values, DataType dtype = DataType.DtInvalid, string name = "", DataType preferred_dtype = DataType.DtInvalid, bool as_ref = false) { - var ret = new List(); + var ret = new List(); foreach((int i, T value) in Python.enumerate(values)) { @@ -325,16 +325,16 @@ namespace Tensorflow return ret.ToArray(); } - public static T internal_convert_to_tensor(T value, DataType dtype = DataType.DtInvalid, + public static Tensor internal_convert_to_tensor(T value, DataType dtype = DataType.DtInvalid, string name = "", DataType preferred_dtype = DataType.DtInvalid, bool as_ref = false) { switch (typeof(T).Name) { case "Tensor": - return value; + return value as Tensor; default: - throw new NotImplementedException("internal_convert_to_tensor"); + return constant_op.constant(np.array(value), name); } } }