diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index cecf89de..33f2a705 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -170,7 +170,7 @@ namespace Tensorflow feeds[i++] = new KeyValuePair(key._as_tf_output(), v); break; case NDArray v: - feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); + feeds[i++] = new KeyValuePair(key._as_tf_output(), v); break; case IntPtr v: var tensor = new Tensor(v); @@ -179,38 +179,9 @@ namespace Tensorflow feeds[i++] = new KeyValuePair(key._as_tf_output(), tensor); break; - // @formatter:off — disable formatter after this line - /*case bool v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case bool[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case sbyte v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case sbyte[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case byte v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case byte[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case short v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case short[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case ushort v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case ushort[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case int v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case int[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case uint v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case uint[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case long v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case long[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case ulong v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case ulong[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case float v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case float[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case double v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case double[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case Complex v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; - case Complex[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;*/ - // @formatter:on — enable formatter after this line - - case string v: - feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); - break; default: - throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? ""}"); + feeds[i++] = new KeyValuePair(key._as_tf_output(), constant_op.constant(x.Value)); + break; } } } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 5d906506..0be6b395 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -70,73 +70,18 @@ namespace Tensorflow if (is_op) { if (tensor_values.Length > 0) - { - switch (tensor_values[0].dtype) - { - case NumpyDType.Int32: - full_values.Add(float.NaN); - break; - case NumpyDType.Single: - full_values.Add(float.NaN); - break; - case NumpyDType.Double: - full_values.Add(float.NaN); - break; - case NumpyDType.String: - full_values.Add(float.NaN); - break; - case NumpyDType.Char: - full_values.Add(float.NaN); - break; - case NumpyDType.Byte: - full_values.Add(float.NaN); - break; - default: - throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype}"); - } - } + full_values.Add(float.NaN); else - { full_values.Add(null); - } } else { var value = tensor_values[j]; j += 1; if (value.ndim == 0) - { - switch (value.dtype) - { - case NumpyDType.Int16: - full_values.Add(value.GetValue(0)); - break; - case NumpyDType.Int32: - full_values.Add(value.GetValue(0)); - break; - case NumpyDType.Int64: - full_values.Add(value.GetValue(0)); - break; - case NumpyDType.Single: - full_values.Add(value.GetValue(0)); - break; - case NumpyDType.Double: - full_values.Add(value.GetValue(0)); - break; - case NumpyDType.Boolean: - full_values.Add(value.GetValue(0)); - break; - /*case "String": - full_values.Add(value.Data()[0]); - break;*/ - default: - throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype}"); - } - } + full_values.Add(value); else - { full_values.Add(value[np.arange(0, (int)value.dims[0])]); - } } i += 1; } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index ccc11f2a..1dc1a493 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -74,6 +74,7 @@ namespace Tensorflow } } + // graph mode Graph g = ops.get_default_graph(); var tensor_value = new AttrValue(); tensor_value.Tensor = tensor_util.make_tensor_proto(value, diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 91b6c9c3..056a5f37 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -116,21 +116,43 @@ namespace Tensorflow return tp; dtype = values.GetType().as_tf_dtype(); - // We first convert value to a numpy array or scalar. var tensor_proto = new TensorProto { Dtype = dtype.as_datatype_enum(), - // TensorShape = tensor_util.as_shape(shape.dims) }; - /*if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) + // scalar + if (!values.GetType().IsArray) { - byte[] bytes = nparray.ToByteArray(); - tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); + tensor_proto.TensorShape = tensor_util.as_shape(new int[0]); + + switch (values) + { + case bool val: + tensor_proto.BoolVal.AddRange(new[] { val }); + break; + case int val: + tensor_proto.IntVal.AddRange(new[] { val }); + break; + case long val: + tensor_proto.Int64Val.AddRange(new[] { val }); + break; + case float val: + tensor_proto.FloatVal.AddRange(new[] { val }); + break; + case double val: + tensor_proto.DoubleVal.AddRange(new[] { val }); + break; + case string val: + tensor_proto.StringVal.AddRange(val.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString()))); + break; + default: + throw new Exception("make_tensor_proto Not Implemented"); + } + return tensor_proto; } - - if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray)) + else if (dtype == TF_DataType.TF_STRING && !(values is NDArray)) { if (values is string str) { @@ -144,33 +166,18 @@ namespace Tensorflow return tensor_proto; } - - var proto_values = nparray.ravel();*/ - - switch (values) + else { - case float val: - tensor_proto.TensorShape = tensor_util.as_shape(new int[0]); - tensor_proto.FloatVal.AddRange(new[] { val }); - break; - /*case "Bool": - case "Boolean": - tensor_proto.BoolVal.AddRange(proto_values.Data()); - break; - case "Int32": - tensor_proto.IntVal.AddRange(proto_values.Data()); - break; - case "Int64": - tensor_proto.Int64Val.AddRange(proto_values.Data()); - break; - case "Double": - tensor_proto.DoubleVal.AddRange(proto_values.Data()); - break; - case "String": - tensor_proto.StringVal.AddRange(proto_values.Data().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString()))); - break;*/ - default: - throw new Exception("make_tensor_proto Not Implemented"); + tensor_proto.TensorShape = tensor_util.as_shape(shape); + + // array + if (_TENSOR_CONTENT_TYPES.Contains(dtype)) + { + throw new NotImplementedException(""); + /*byte[] bytes = nparray.ToByteArray(); + tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); + return tensor_proto;*/ + } } return tensor_proto; diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index f0101151..08e1bc4a 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -46,7 +46,7 @@ namespace TensorFlowNET.UnitTest.Basics var o = sess.run(c, new FeedItem(a, 3.0f), new FeedItem(b, 2.0f)); - Assert.AreEqual((float)o, 5.0f); + Assert.AreEqual(o, 5.0f); } } @@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics using (var sess = tf.Session()) { var o = sess.run(c); - Assert.AreEqual((float)o, 9.0f); + Assert.AreEqual(o, 9.0f); } }