From 1d2fe5b2c6a75964a297fb8dd4d8ebd8d08eb85e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 27 Dec 2018 11:28:55 -0600 Subject: [PATCH] Scalar and String constant creation. --- .../Operations/Operation.cs | 2 +- .../Operations/gen_math_ops.cs | 2 +- .../op_list_proto_array.bin | Bin .../op_list_proto_math.bin | Bin src/TensorFlowNET.Core/Operations/ops.cs | 10 ++- .../TensorFlowNET.Core.csproj | 4 +- src/TensorFlowNET.Core/Tensors/RefVariable.cs | 6 +- src/TensorFlowNET.Core/Tensors/Tensor.cs | 6 +- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 17 ++--- src/TensorFlowNET.Core/Tensors/constant_op.cs | 41 ++++++++++++ src/TensorFlowNET.Core/Tensors/dtypes.cs | 59 +++++++++++++++++ src/TensorFlowNET.Core/Tensors/tensor_util.cs | 62 ++++++++++++++---- src/TensorFlowNET.Core/Tensors/tf.constant.cs | 14 ++++ src/TensorFlowNET.Core/tf.cs | 21 +----- test/TensorFlowNET.UnitTest/ConstantTest.cs | 28 ++++++++ test/TensorFlowNET.UnitTest/OperationsTest.cs | 6 -- test/TensorFlowNET.UnitTest/VariableTest.cs | 1 + 17 files changed, 214 insertions(+), 65 deletions(-) rename src/TensorFlowNET.Core/{Protobuf => Operations}/op_list_proto_array.bin (100%) rename src/TensorFlowNET.Core/{Protobuf => Operations}/op_list_proto_math.bin (100%) create mode 100644 src/TensorFlowNET.Core/Tensors/constant_op.cs create mode 100644 src/TensorFlowNET.Core/Tensors/dtypes.cs create mode 100644 src/TensorFlowNET.Core/Tensors/tf.constant.cs create mode 100644 test/TensorFlowNET.UnitTest/ConstantTest.cs diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 70ed5d58..d7908683 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -39,7 +39,7 @@ namespace Tensorflow _outputs = new Tensor[num_outputs]; for (int i = 0; i < num_outputs; i++) { - _outputs[i] = new Tensor(this, i, TF_DataType.TF_FLOAT); + _outputs[i] = new Tensor(this, i, output_types[i]); } _graph._add_op(this); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 6aae72bd..9d5dbc21 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -25,7 +25,7 @@ namespace Tensorflow private static OpDefLibrary _InitOpDefLibrary() { // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); - var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_math.bin"); + var bytes = File.ReadAllBytes("Operations/op_list_proto_math.bin"); var op_list = OpList.Parser.ParseFrom(bytes); var op_def_lib = new OpDefLibrary(); op_def_lib.add_op_list(op_list); diff --git a/src/TensorFlowNET.Core/Protobuf/op_list_proto_array.bin b/src/TensorFlowNET.Core/Operations/op_list_proto_array.bin similarity index 100% rename from src/TensorFlowNET.Core/Protobuf/op_list_proto_array.bin rename to src/TensorFlowNET.Core/Operations/op_list_proto_array.bin diff --git a/src/TensorFlowNET.Core/Protobuf/op_list_proto_math.bin b/src/TensorFlowNET.Core/Operations/op_list_proto_math.bin similarity index 100% rename from src/TensorFlowNET.Core/Protobuf/op_list_proto_math.bin rename to src/TensorFlowNET.Core/Operations/op_list_proto_math.bin diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/Operations/ops.cs index f7eb3853..a491b6df 100644 --- a/src/TensorFlowNET.Core/Operations/ops.cs +++ b/src/TensorFlowNET.Core/Operations/ops.cs @@ -16,18 +16,16 @@ namespace Tensorflow return tf.Graph(); } - public static Tensor convert_to_tensor() + public static Tensor convert_to_tensor(object value, string name = "") { - return internal_convert_to_tensor(); + return internal_convert_to_tensor(value, name); } - private static Tensor internal_convert_to_tensor() + private static Tensor internal_convert_to_tensor(object value, string name = "") { - return null; + return tf.constant(value); } - - public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs) { var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 8689e8c9..60437b9d 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -27,10 +27,10 @@ PreserveNewest - + PreserveNewest - + PreserveNewest diff --git a/src/TensorFlowNET.Core/Tensors/RefVariable.cs b/src/TensorFlowNET.Core/Tensors/RefVariable.cs index 129d0618..bbe5996a 100644 --- a/src/TensorFlowNET.Core/Tensors/RefVariable.cs +++ b/src/TensorFlowNET.Core/Tensors/RefVariable.cs @@ -7,6 +7,7 @@ namespace Tensorflow public class RefVariable : Variable { public bool _in_graph_mode = true; + public Tensor _initial_value; public RefVariable(object initial_value, TF_DataType trainable, @@ -16,9 +17,10 @@ namespace Tensorflow } - private void _init_from_args() + private void _init_from_args(object initial_value, + TF_DataType trainable) { - + _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 74ed8e5f..28a878f3 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -13,12 +13,12 @@ namespace Tensorflow /// public class Tensor { - public Operation op { get; } - public int value_index { get; } - public Graph graph => op.graph; + public Operation op { get; } public string name; + public object value; + public int value_index { get; } public TF_DataType dtype { get; } public IntPtr handle { get; } diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 5c3cf87b..7a5b0d88 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -6,23 +6,14 @@ using System.Text; namespace Tensorflow { + /// + /// Represents the shape of a `Tensor`. + /// public class TensorShape : Shape { - public TensorShape(params int[] shape) : base(shape) + public TensorShape(params int[] dims) : base(dims) { } - - public TensorShape as_shape() - { - return this; - } - - public TensorShapeProto as_proto() - { - TensorShapeProto dim = new TensorShapeProto(); - - return new TensorShapeProto(dim); - } } } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs new file mode 100644 index 00000000..05f37685 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -0,0 +1,41 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class constant_op + { + /// + /// Creates a constant tensor. + /// + /// The resulting tensor is populated with values of type `dtype`, as + /// specified by arguments `value` and (optionally) `shape` + /// + /// A constant value (or list) of output type `dtype`. + /// The type of the elements of the resulting tensor. + /// Optional dimensions of resulting tensor. + /// Optional name for the tensor. + /// Boolean that enables verification of a shape of values. + /// + public static Tensor Create(object value, TF_DataType dtype = TF_DataType.DtInvalid, TensorShape shape = null, string name = "Const", bool verify_shape = false) + { + Graph g = ops.get_default_graph(); + var tensor_value = new AttrValue(); + var tensor_pb = tensor_util.make_tensor_proto(value, dtype, shape, verify_shape); + tensor_value.Tensor = tensor_pb; + var dtype_value = new AttrValue + { + Type = tensor_value.Tensor.Dtype, + }; + + var attrs = new Dictionary(); + attrs["dtype"] = dtype_value; + attrs["value"] = tensor_value; + var const_tensor = g.create_op("Const", null, new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs).outputs[0]; + const_tensor.value = value; + + return const_tensor; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs new file mode 100644 index 00000000..d8f4bcba --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -0,0 +1,59 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static class dtypes + { + public static TF_DataType as_dtype(Type type) + { + TF_DataType dtype = TF_DataType.DtInvalid; + + switch (type.Name) + { + case "Int32": + dtype = TF_DataType.TF_INT32; + break; + case "Single": + dtype = TF_DataType.TF_FLOAT; + break; + case "Double": + dtype = TF_DataType.TF_DOUBLE; + break; + case "String": + dtype = TF_DataType.TF_STRING; + break; + default: + throw new Exception("Not Implemented"); + } + + return dtype; + } + + public static DataType as_datatype_enum(this TF_DataType type) + { + DataType dtype = DataType.DtInvalid; + + switch (type) + { + case TF_DataType.TF_INT32: + dtype = DataType.DtInt32; + break; + case TF_DataType.TF_FLOAT: + dtype = DataType.DtFloat; + break; + case TF_DataType.TF_DOUBLE: + dtype = DataType.DtDouble; + break; + case TF_DataType.TF_STRING: + dtype = DataType.DtString; + break; + default: + throw new Exception("Not Implemented"); + } + + return dtype; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 1c6da7bc..70d5cf9f 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -8,44 +8,84 @@ namespace Tensorflow { public static class tensor_util { - public static TensorProto make_tensor_proto(object values, Type dtype = null) + public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, Shape shape = null, bool verify_shape = false) { NDArray nparray; TensorProto tensor_proto = null; - TensorShape tensor_shape = new TensorShape(0); + TF_DataType numpy_dtype; + if(shape is null) + { + shape = new Shape(); + } switch (values) { + case int val: + nparray = np.asarray(val); + numpy_dtype = dtypes.as_dtype(nparray.dtype); + tensor_proto = new tensor_pb2.TensorProto + { + Dtype = numpy_dtype.as_datatype_enum(), + TensorShape = shape.as_shape(nparray.shape).as_proto() + }; + tensor_proto.IntVal.Add(val); + break; case float val: - nparray = np.array(new float[] { val }, np.float32); + nparray = np.asarray(val); + numpy_dtype = dtypes.as_dtype(nparray.dtype); tensor_proto = new tensor_pb2.TensorProto { - Dtype = DataType.DtFloat, - TensorShape = tensor_shape.as_shape().as_proto() + Dtype = numpy_dtype.as_datatype_enum(), + TensorShape = shape.as_shape(nparray.shape).as_proto() }; tensor_proto.FloatVal.Add(val); break; case double val: - nparray = np.array(new double[] { val }, np.float64); + nparray = np.asarray(val); + numpy_dtype = dtypes.as_dtype(nparray.dtype); tensor_proto = new tensor_pb2.TensorProto { - Dtype = DataType.DtDouble, - TensorShape = tensor_shape.as_shape().as_proto() + Dtype = numpy_dtype.as_datatype_enum(), + TensorShape = shape.as_shape(nparray.shape).as_proto() }; tensor_proto.DoubleVal.Add(val); break; case string val: - nparray = np.array(new string[] { val }, np.chars); + nparray = np.asarray(val); + numpy_dtype = dtypes.as_dtype(nparray.dtype); tensor_proto = new tensor_pb2.TensorProto { - Dtype = DataType.DtString, - TensorShape = tensor_shape.as_shape().as_proto() + Dtype = numpy_dtype.as_datatype_enum(), + TensorShape = shape.as_shape(nparray.shape).as_proto() }; tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFrom(val, Encoding.UTF8)); break; + default: + throw new Exception("Not Implemented"); } return tensor_proto; } + + public static TensorShape as_shape(this Shape shape, int[] dims) + { + return new TensorShape(dims); + } + + public static TensorShapeProto as_proto(this TensorShape tshape) + { + TensorShapeProto shape = new TensorShapeProto(); + + for (int i = 0; i < tshape.NDim; i++) + { + var dim = new TensorShapeProto.Types.Dim(); + dim.Size = tshape.Dimensions[i]; + dim.Name = $"{dim}_1"; + + shape.Dim.Add(dim); + } + + return shape; + } } } diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs new file mode 100644 index 00000000..82955028 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, TensorShape shape = null, string name = "Const", bool verify_shape = false) + { + return constant_op.Create(value, dtype, shape, name, verify_shape); + } + } +} diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 09783aa1..d50ccdb1 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -8,7 +8,7 @@ using Tensorflow.Eager; namespace Tensorflow { - public static class tf + public static partial class tf { public static TF_DataType float32 = TF_DataType.TF_FLOAT; public static TF_DataType chars = TF_DataType.TF_STRING; @@ -32,25 +32,6 @@ namespace Tensorflow return gen_array_ops.placeholder(dtype, shape); } - public static unsafe Tensor constant(object value) - { - var g = ops.get_default_graph(); - var tensor_value = new attr_value_pb2.AttrValue(); - var tensor_pb = tensor_util.make_tensor_proto(value); - tensor_value.Tensor = tensor_pb; - var dtype_value = new attr_value_pb2.AttrValue - { - Type = tensor_value.Tensor.Dtype, - }; - - var attrs = new Dictionary(); - attrs["dtype"] = dtype_value; - attrs["value"] = tensor_value; - var const_tensor = g.create_op("Const", null, new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs).outputs[0]; - - return const_tensor; - } - public static void enable_eager_execution() { context.default_execution_mode = Context.EAGER_MODE; diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs new file mode 100644 index 00000000..f3b30e6a --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -0,0 +1,28 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class ConstantTest + { + Tensor tensor; + + [TestMethod] + public void ScalarConst() + { + tensor = tf.constant(8); // int + tensor = tf.constant(6.0f); // float + tensor = tf.constant(6.0); // double + } + + [TestMethod] + public void StringConst() + { + tensor = tf.constant("Elephant"); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 23b17070..47b849d9 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -9,12 +9,6 @@ namespace TensorFlowNET.UnitTest [TestClass] public class OperationsTest { - [TestMethod] - public void constant() - { - var x = tf.constant(4.0f); - } - [TestMethod] public void placeholder() { diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index a761d93e..970a5670 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -9,6 +9,7 @@ namespace TensorFlowNET.UnitTest [TestClass] public class VariableTest { + [TestMethod] public void Creating() { var mammal = tf.Variable("Elephant", tf.chars);