| @@ -170,7 +170,7 @@ namespace Tensorflow | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v); | |||
| break; | |||
| case NDArray v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v); | |||
| break; | |||
| case IntPtr v: | |||
| var tensor = new Tensor(v); | |||
| @@ -179,38 +179,9 @@ namespace Tensorflow | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor); | |||
| break; | |||
| // @formatter:off — disable formatter after this line | |||
| /*case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case bool[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; | |||
| case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;*/ | |||
| // @formatter:on — enable formatter after this line | |||
| case string v: | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | |||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), constant_op.constant(x.Value)); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| @@ -70,73 +70,18 @@ namespace Tensorflow | |||
| if (is_op) | |||
| { | |||
| if (tensor_values.Length > 0) | |||
| { | |||
| switch (tensor_values[0].dtype) | |||
| { | |||
| case NumpyDType.Int32: | |||
| full_values.Add(float.NaN); | |||
| break; | |||
| case NumpyDType.Single: | |||
| full_values.Add(float.NaN); | |||
| break; | |||
| case NumpyDType.Double: | |||
| full_values.Add(float.NaN); | |||
| break; | |||
| case NumpyDType.String: | |||
| full_values.Add(float.NaN); | |||
| break; | |||
| case NumpyDType.Char: | |||
| full_values.Add(float.NaN); | |||
| break; | |||
| case NumpyDType.Byte: | |||
| full_values.Add(float.NaN); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype}"); | |||
| } | |||
| } | |||
| full_values.Add(float.NaN); | |||
| else | |||
| { | |||
| full_values.Add(null); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| var value = tensor_values[j]; | |||
| j += 1; | |||
| if (value.ndim == 0) | |||
| { | |||
| switch (value.dtype) | |||
| { | |||
| case NumpyDType.Int16: | |||
| full_values.Add(value.GetValue<short>(0)); | |||
| break; | |||
| case NumpyDType.Int32: | |||
| full_values.Add(value.GetValue<int>(0)); | |||
| break; | |||
| case NumpyDType.Int64: | |||
| full_values.Add(value.GetValue<long>(0)); | |||
| break; | |||
| case NumpyDType.Single: | |||
| full_values.Add(value.GetValue<float>(0)); | |||
| break; | |||
| case NumpyDType.Double: | |||
| full_values.Add(value.GetValue<double>(0)); | |||
| break; | |||
| case NumpyDType.Boolean: | |||
| full_values.Add(value.GetValue<bool>(0)); | |||
| break; | |||
| /*case "String": | |||
| full_values.Add(value.Data<byte>()[0]); | |||
| break;*/ | |||
| default: | |||
| throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype}"); | |||
| } | |||
| } | |||
| full_values.Add(value); | |||
| else | |||
| { | |||
| full_values.Add(value[np.arange(0, (int)value.dims[0])]); | |||
| } | |||
| } | |||
| i += 1; | |||
| } | |||
| @@ -74,6 +74,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| // graph mode | |||
| Graph g = ops.get_default_graph(); | |||
| var tensor_value = new AttrValue(); | |||
| tensor_value.Tensor = tensor_util.make_tensor_proto(value, | |||
| @@ -116,21 +116,43 @@ namespace Tensorflow | |||
| return tp; | |||
| dtype = values.GetType().as_tf_dtype(); | |||
| // We first convert value to a numpy array or scalar. | |||
| var tensor_proto = new TensorProto | |||
| { | |||
| Dtype = dtype.as_datatype_enum(), | |||
| // TensorShape = tensor_util.as_shape(shape.dims) | |||
| }; | |||
| /*if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) | |||
| // scalar | |||
| if (!values.GetType().IsArray) | |||
| { | |||
| byte[] bytes = nparray.ToByteArray(); | |||
| tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); | |||
| tensor_proto.TensorShape = tensor_util.as_shape(new int[0]); | |||
| switch (values) | |||
| { | |||
| case bool val: | |||
| tensor_proto.BoolVal.AddRange(new[] { val }); | |||
| break; | |||
| case int val: | |||
| tensor_proto.IntVal.AddRange(new[] { val }); | |||
| break; | |||
| case long val: | |||
| tensor_proto.Int64Val.AddRange(new[] { val }); | |||
| break; | |||
| case float val: | |||
| tensor_proto.FloatVal.AddRange(new[] { val }); | |||
| break; | |||
| case double val: | |||
| tensor_proto.DoubleVal.AddRange(new[] { val }); | |||
| break; | |||
| case string val: | |||
| tensor_proto.StringVal.AddRange(val.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString()))); | |||
| break; | |||
| default: | |||
| throw new Exception("make_tensor_proto Not Implemented"); | |||
| } | |||
| return tensor_proto; | |||
| } | |||
| if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray)) | |||
| else if (dtype == TF_DataType.TF_STRING && !(values is NDArray)) | |||
| { | |||
| if (values is string str) | |||
| { | |||
| @@ -144,33 +166,18 @@ namespace Tensorflow | |||
| return tensor_proto; | |||
| } | |||
| var proto_values = nparray.ravel();*/ | |||
| switch (values) | |||
| else | |||
| { | |||
| case float val: | |||
| tensor_proto.TensorShape = tensor_util.as_shape(new int[0]); | |||
| tensor_proto.FloatVal.AddRange(new[] { val }); | |||
| break; | |||
| /*case "Bool": | |||
| case "Boolean": | |||
| tensor_proto.BoolVal.AddRange(proto_values.Data<bool>()); | |||
| break; | |||
| case "Int32": | |||
| tensor_proto.IntVal.AddRange(proto_values.Data<int>()); | |||
| break; | |||
| case "Int64": | |||
| tensor_proto.Int64Val.AddRange(proto_values.Data<long>()); | |||
| break; | |||
| case "Double": | |||
| tensor_proto.DoubleVal.AddRange(proto_values.Data<double>()); | |||
| break; | |||
| case "String": | |||
| tensor_proto.StringVal.AddRange(proto_values.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString()))); | |||
| break;*/ | |||
| default: | |||
| throw new Exception("make_tensor_proto Not Implemented"); | |||
| tensor_proto.TensorShape = tensor_util.as_shape(shape); | |||
| // array | |||
| if (_TENSOR_CONTENT_TYPES.Contains(dtype)) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| /*byte[] bytes = nparray.ToByteArray(); | |||
| tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); | |||
| return tensor_proto;*/ | |||
| } | |||
| } | |||
| return tensor_proto; | |||
| @@ -46,7 +46,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| var o = sess.run(c, | |||
| new FeedItem(a, 3.0f), | |||
| new FeedItem(b, 2.0f)); | |||
| Assert.AreEqual((float)o, 5.0f); | |||
| Assert.AreEqual(o, 5.0f); | |||
| } | |||
| } | |||
| @@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var o = sess.run(c); | |||
| Assert.AreEqual((float)o, 9.0f); | |||
| Assert.AreEqual(o, 9.0f); | |||
| } | |||
| } | |||