diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 8936dd3d..3b2df95d 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -9,7 +9,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{DA680126-DA60-4CE3-9094-72C355C081D3}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -29,10 +29,10 @@ Global {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU - {6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Debug|Any CPU.Build.0 = Debug|Any CPU - {6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Release|Any CPU.ActiveCfg = Release|Any CPU - {6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Release|Any CPU.Build.0 = Release|Any CPU + {DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {DA680126-DA60-4CE3-9094-72C355C081D3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {DA680126-DA60-4CE3-9094-72C355C081D3}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/Operations/ops.cs index a491b6df..a0ea4531 100644 --- a/src/TensorFlowNET.Core/Operations/ops.cs +++ b/src/TensorFlowNET.Core/Operations/ops.cs @@ -18,12 +18,8 @@ namespace Tensorflow public static Tensor convert_to_tensor(object value, string name = "") { - return internal_convert_to_tensor(value, name); - } - - private static Tensor internal_convert_to_tensor(object value, string name = "") - { - return tf.constant(value); + var nd = tensor_util.convert_to_numpy_ndarray(value); + return tf.constant(nd, name); } public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs) diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 60437b9d..c35f0aa2 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -4,6 +4,18 @@ netstandard2.0 TensorFlow.NET Tensorflow + 0.0.2 + Haiping Chen + SciSharp.org + true + Apache 2.0 + https://github.com/SciSharp/TensorFlow.NET + git + https://github.com/SciSharp + https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET + TensorFlow binding for .NET Standard. + 0.0.2.0 @@ -11,12 +23,17 @@ DEBUG;TRACE + + true + + + diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 05f37685..0beda95e 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp.Core; +using System; using System.Collections.Generic; using System.Text; @@ -18,11 +19,11 @@ namespace Tensorflow /// Optional name for the tensor. /// Boolean that enables verification of a shape of values. /// - public static Tensor Create(object value, TF_DataType dtype = TF_DataType.DtInvalid, TensorShape shape = null, string name = "Const", bool verify_shape = false) + public static Tensor Create(NDArray nd, string name = "Const", bool verify_shape = false) { Graph g = ops.get_default_graph(); var tensor_value = new AttrValue(); - var tensor_pb = tensor_util.make_tensor_proto(value, dtype, shape, verify_shape); + var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape); tensor_value.Tensor = tensor_pb; var dtype_value = new AttrValue { @@ -33,7 +34,7 @@ namespace Tensorflow attrs["dtype"] = dtype_value; attrs["value"] = tensor_value; var const_tensor = g.create_op("Const", null, new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs).outputs[0]; - const_tensor.value = value; + const_tensor.value = nd.Data(); return const_tensor; } diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 70d5cf9f..052adee6 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -1,4 +1,5 @@ using NumSharp.Core; +using NumSharp.Core.Interfaces; using System; using System.Collections.Generic; using System.Text; @@ -8,66 +9,70 @@ namespace Tensorflow { public static class tensor_util { - public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, Shape shape = null, bool verify_shape = false) + public static TensorProto make_tensor_proto(NDArray nd, bool verify_shape = false) { - NDArray nparray; - TensorProto tensor_proto = null; - TF_DataType numpy_dtype; - if(shape is null) + var shape = nd.Storage.Shape; + + var numpy_dtype = dtypes.as_dtype(nd.dtype); + var tensor_proto = new tensor_pb2.TensorProto + { + Dtype = numpy_dtype.as_datatype_enum(), + TensorShape = shape.as_shape(nd.shape).as_proto() + }; + + switch (nd.dtype.Name) { - shape = new Shape(); + case "Int32": + tensor_proto.IntVal.AddRange(nd.Data()); + break; + case "Single": + tensor_proto.FloatVal.AddRange(nd.Data()); + break; + case "Double": + tensor_proto.DoubleVal.AddRange(nd.Data()); + break; + case "String": + tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFrom(nd.Data()[0], Encoding.UTF8)); + break; + default: + throw new Exception("Not Implemented"); } + return tensor_proto; + } + + public static NDArray convert_to_numpy_ndarray(object values) + { + NDArray nd; + switch (values) { + case NDArray val: + nd = val; + break; case int val: - nparray = np.asarray(val); - numpy_dtype = dtypes.as_dtype(nparray.dtype); - tensor_proto = new tensor_pb2.TensorProto - { - Dtype = numpy_dtype.as_datatype_enum(), - TensorShape = shape.as_shape(nparray.shape).as_proto() - }; - tensor_proto.IntVal.Add(val); + nd = np.asarray(val); + break; + case int[] val: + nd = np.array(val); break; case float val: - nparray = np.asarray(val); - numpy_dtype = dtypes.as_dtype(nparray.dtype); - tensor_proto = new tensor_pb2.TensorProto - { - Dtype = numpy_dtype.as_datatype_enum(), - TensorShape = shape.as_shape(nparray.shape).as_proto() - }; - tensor_proto.FloatVal.Add(val); + nd = np.asarray(val); break; case double val: - nparray = np.asarray(val); - numpy_dtype = dtypes.as_dtype(nparray.dtype); - tensor_proto = new tensor_pb2.TensorProto - { - Dtype = numpy_dtype.as_datatype_enum(), - TensorShape = shape.as_shape(nparray.shape).as_proto() - }; - tensor_proto.DoubleVal.Add(val); + nd = np.asarray(val); break; case string val: - nparray = np.asarray(val); - numpy_dtype = dtypes.as_dtype(nparray.dtype); - tensor_proto = new tensor_pb2.TensorProto - { - Dtype = numpy_dtype.as_datatype_enum(), - TensorShape = shape.as_shape(nparray.shape).as_proto() - }; - tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFrom(val, Encoding.UTF8)); + nd = np.asarray(val); break; default: throw new Exception("Not Implemented"); } - return tensor_proto; + return nd; } - public static TensorShape as_shape(this Shape shape, int[] dims) + public static TensorShape as_shape(this IShape shape, int[] dims) { return new TensorShape(dims); } @@ -80,7 +85,7 @@ namespace Tensorflow { var dim = new TensorShapeProto.Types.Dim(); dim.Size = tshape.Dimensions[i]; - dim.Name = $"{dim}_1"; + dim.Name = $"dim_{i}"; shape.Dim.Add(dim); } diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs index 82955028..df56b835 100644 --- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp.Core; +using System; using System.Collections.Generic; using System.Text; @@ -6,9 +7,9 @@ namespace Tensorflow { public static partial class tf { - public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, TensorShape shape = null, string name = "Const", bool verify_shape = false) + public static Tensor constant(NDArray value, string name = "Const", bool verify_shape = false) { - return constant_op.Create(value, dtype, shape, name, verify_shape); + return constant_op.Create(value, name, verify_shape); } } } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 627d1088..bf59b53f 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -5,6 +5,10 @@ netcoreapp2.1 + + + + diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs index f3b30e6a..76348a77 100644 --- a/test/TensorFlowNET.UnitTest/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp.Core; using System; using System.Collections.Generic; using System.Text; @@ -24,5 +25,17 @@ namespace TensorFlowNET.UnitTest { tensor = tf.constant("Elephant"); } + + [TestMethod] + public void NDimConst() + { + var nd = np.array(new int[][] + { + new int[]{ 1, 2, 3 }, + new int[]{ 4, 5, 6 } + }); + + tensor = tf.constant(nd); + } } } diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index de2eb99e..e93f4713 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -11,10 +11,15 @@ true + + true + + +