Browse Source

fixt build_results.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
4c51d1b678
5 changed files with 48 additions and 124 deletions
  1. +3
    -32
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +2
    -57
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  4. +40
    -33
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  5. +2
    -2
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 3
- 32
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -170,7 +170,7 @@ namespace Tensorflow
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v); feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v);
break; break;
case NDArray v: case NDArray v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype));
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v);
break; break;
case IntPtr v: case IntPtr v:
var tensor = new Tensor(v); var tensor = new Tensor(v);
@@ -179,38 +179,9 @@ namespace Tensorflow


feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor); feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor);
break; break;
// @formatter:off — disable formatter after this line
/*case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case bool[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;*/
// @formatter:on — enable formatter after this line

case string v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype));
break;
default: default:
throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}");
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), constant_op.constant(x.Value));
break;
} }
} }
} }


+ 2
- 57
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -70,73 +70,18 @@ namespace Tensorflow
if (is_op) if (is_op)
{ {
if (tensor_values.Length > 0) if (tensor_values.Length > 0)
{
switch (tensor_values[0].dtype)
{
case NumpyDType.Int32:
full_values.Add(float.NaN);
break;
case NumpyDType.Single:
full_values.Add(float.NaN);
break;
case NumpyDType.Double:
full_values.Add(float.NaN);
break;
case NumpyDType.String:
full_values.Add(float.NaN);
break;
case NumpyDType.Char:
full_values.Add(float.NaN);
break;
case NumpyDType.Byte:
full_values.Add(float.NaN);
break;
default:
throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype}");
}
}
full_values.Add(float.NaN);
else else
{
full_values.Add(null); full_values.Add(null);
}
} }
else else
{ {
var value = tensor_values[j]; var value = tensor_values[j];
j += 1; j += 1;
if (value.ndim == 0) if (value.ndim == 0)
{
switch (value.dtype)
{
case NumpyDType.Int16:
full_values.Add(value.GetValue<short>(0));
break;
case NumpyDType.Int32:
full_values.Add(value.GetValue<int>(0));
break;
case NumpyDType.Int64:
full_values.Add(value.GetValue<long>(0));
break;
case NumpyDType.Single:
full_values.Add(value.GetValue<float>(0));
break;
case NumpyDType.Double:
full_values.Add(value.GetValue<double>(0));
break;
case NumpyDType.Boolean:
full_values.Add(value.GetValue<bool>(0));
break;
/*case "String":
full_values.Add(value.Data<byte>()[0]);
break;*/
default:
throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype}");
}
}
full_values.Add(value);
else else
{
full_values.Add(value[np.arange(0, (int)value.dims[0])]); full_values.Add(value[np.arange(0, (int)value.dims[0])]);
}
} }
i += 1; i += 1;
} }


+ 1
- 0
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -74,6 +74,7 @@ namespace Tensorflow
} }
} }


// graph mode
Graph g = ops.get_default_graph(); Graph g = ops.get_default_graph();
var tensor_value = new AttrValue(); var tensor_value = new AttrValue();
tensor_value.Tensor = tensor_util.make_tensor_proto(value, tensor_value.Tensor = tensor_util.make_tensor_proto(value,


+ 40
- 33
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -116,21 +116,43 @@ namespace Tensorflow
return tp; return tp;


dtype = values.GetType().as_tf_dtype(); dtype = values.GetType().as_tf_dtype();
// We first convert value to a numpy array or scalar.
var tensor_proto = new TensorProto var tensor_proto = new TensorProto
{ {
Dtype = dtype.as_datatype_enum(), Dtype = dtype.as_datatype_enum(),
// TensorShape = tensor_util.as_shape(shape.dims)
}; };


/*if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1)
// scalar
if (!values.GetType().IsArray)
{ {
byte[] bytes = nparray.ToByteArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray());
tensor_proto.TensorShape = tensor_util.as_shape(new int[0]);

switch (values)
{
case bool val:
tensor_proto.BoolVal.AddRange(new[] { val });
break;
case int val:
tensor_proto.IntVal.AddRange(new[] { val });
break;
case long val:
tensor_proto.Int64Val.AddRange(new[] { val });
break;
case float val:
tensor_proto.FloatVal.AddRange(new[] { val });
break;
case double val:
tensor_proto.DoubleVal.AddRange(new[] { val });
break;
case string val:
tensor_proto.StringVal.AddRange(val.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString())));
break;
default:
throw new Exception("make_tensor_proto Not Implemented");
}

return tensor_proto; return tensor_proto;
} }

if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray))
else if (dtype == TF_DataType.TF_STRING && !(values is NDArray))
{ {
if (values is string str) if (values is string str)
{ {
@@ -144,33 +166,18 @@ namespace Tensorflow


return tensor_proto; return tensor_proto;
} }

var proto_values = nparray.ravel();*/
switch (values)
else
{ {
case float val:
tensor_proto.TensorShape = tensor_util.as_shape(new int[0]);
tensor_proto.FloatVal.AddRange(new[] { val });
break;
/*case "Bool":
case "Boolean":
tensor_proto.BoolVal.AddRange(proto_values.Data<bool>());
break;
case "Int32":
tensor_proto.IntVal.AddRange(proto_values.Data<int>());
break;
case "Int64":
tensor_proto.Int64Val.AddRange(proto_values.Data<long>());
break;
case "Double":
tensor_proto.DoubleVal.AddRange(proto_values.Data<double>());
break;
case "String":
tensor_proto.StringVal.AddRange(proto_values.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString())));
break;*/
default:
throw new Exception("make_tensor_proto Not Implemented");
tensor_proto.TensorShape = tensor_util.as_shape(shape);

// array
if (_TENSOR_CONTENT_TYPES.Contains(dtype))
{
throw new NotImplementedException("");
/*byte[] bytes = nparray.ToByteArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray());
return tensor_proto;*/
}
} }


return tensor_proto; return tensor_proto;


+ 2
- 2
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -46,7 +46,7 @@ namespace TensorFlowNET.UnitTest.Basics
var o = sess.run(c, var o = sess.run(c,
new FeedItem(a, 3.0f), new FeedItem(a, 3.0f),
new FeedItem(b, 2.0f)); new FeedItem(b, 2.0f));
Assert.AreEqual((float)o, 5.0f);
Assert.AreEqual(o, 5.0f);
} }
} }


@@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics
using (var sess = tf.Session()) using (var sess = tf.Session())
{ {
var o = sess.run(c); var o = sess.run(c);
Assert.AreEqual((float)o, 9.0f);
Assert.AreEqual(o, 9.0f);
} }
} }




Loading…
Cancel
Save