Browse Source

fix tensor creation from n-dim array.

tags/v0.20
Oceania2018 5 years ago
parent
commit
7ec1422588
2 changed files with 28 additions and 6 deletions
  1. +16
    -2
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  2. +12
    -4
      test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs

+ 16
- 2
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -142,18 +142,32 @@ namespace Tensorflow
return new EagerTensor(val, ctx.device_name); return new EagerTensor(val, ctx.device_name);
case int[,] val: case int[,] val:
return new EagerTensor(val, ctx.device_name); return new EagerTensor(val, ctx.device_name);
case int[,,] val:
return new EagerTensor(val, ctx.device_name);
case long val: case long val:
return new EagerTensor(val, ctx.device_name); return new EagerTensor(val, ctx.device_name);
case long[] val:
return new EagerTensor(val, ctx.device_name);
case long[,] val:
return new EagerTensor(val, ctx.device_name);
case long[,,] val:
return new EagerTensor(val, ctx.device_name);
case float val: case float val:
return new EagerTensor(val, ctx.device_name); return new EagerTensor(val, ctx.device_name);
case float[] val:
return new EagerTensor(val, ctx.device_name);
case float[,] val: case float[,] val:
return new EagerTensor(val, ctx.device_name); return new EagerTensor(val, ctx.device_name);
case double val:
case float[,,] val:
return new EagerTensor(val, ctx.device_name); return new EagerTensor(val, ctx.device_name);
case float[] val:
case double val:
return new EagerTensor(val, ctx.device_name); return new EagerTensor(val, ctx.device_name);
case double[] val: case double[] val:
return new EagerTensor(val, ctx.device_name); return new EagerTensor(val, ctx.device_name);
case double[,] val:
return new EagerTensor(val, ctx.device_name);
case double[,,] val:
return new EagerTensor(val, ctx.device_name);
default: default:
throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}"); throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}");
} }


+ 12
- 4
test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs View File

@@ -26,10 +26,18 @@ namespace Tensorflow.UnitTest.TF_API
[TestMethod] [TestMethod]
public void InitTensorTest() public void InitTensorTest()
{ {
var a = tf.constant(new NDArray(new[, ,] { { { 1 }, { 2 }, { 3 } }, { { 4 }, { 5 }, { 6 } } }));
var b = tf.constant(new[, ,] { { { 1 }, { 2 }, { 3 } }, { { 4 }, { 5 }, { 6 } } });
//Test Result : a is OK , and b is error .
var a = tf.constant(np.array(new[, ,]
{
{ { 1 }, { 2 }, { 3 } },
{ { 4 }, { 5 }, { 6 } }
}));
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape)); Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape));

var b = tf.constant(new[, ,]
{
{ { 1 }, { 2 }, { 3 } },
{ { 4 }, { 5 }, { 6 } }
});
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape)); Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape));
} }


@@ -46,7 +54,7 @@ namespace Tensorflow.UnitTest.TF_API


[TestMethod] [TestMethod]
public void ConcatDoubleTest() public void ConcatDoubleTest()
{//double type has some error
{
var a = tf.constant(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }); var a = tf.constant(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } });
var b = tf.constant(new[,] { { 5.0, 6.0 }, { 7.0, 8.0 } }); var b = tf.constant(new[,] { { 5.0, 6.0 }, { 7.0, 8.0 } });
var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } }); var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } });


Loading…
Cancel
Save