From 11572bf77032505571a337a3a233ba213899de25 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Mon, 11 Feb 2019 17:41:59 -0600 Subject: [PATCH] bunch of updates. --- src/TensorFlowNET.Core/APIs/tf.random.cs | 26 +++ .../Framework/random_seed.py.cs | 14 ++ src/TensorFlowNET.Core/Graphs/Graph.cs | 5 +- .../Operations/OpDefLibrary.cs | 25 ++- .../Operations/array_ops.py.cs | 10 +- .../Operations/gen_array_ops.cs | 2 +- .../Operations/gen_random_ops.py.cs | 33 ++++ .../Operations/math_ops.py.cs | 2 + .../Operations/random_ops.py.cs | 35 ++++ .../Sessions/BaseSession.cs | 12 +- src/TensorFlowNET.Core/Tensors/constant_op.cs | 27 +-- src/TensorFlowNET.Core/Tensors/dtypes.cs | 14 ++ src/TensorFlowNET.Core/Tensors/tensor_util.cs | 167 ++++++++++++++++-- src/TensorFlowNET.Core/Tensors/tf.constant.cs | 10 +- .../Train/Saving/BaseSaverBuilder.cs | 4 +- src/TensorFlowNET.Core/ops.py.cs | 81 ++++----- test/TensorFlowNET.UnitTest/ConstantTest.cs | 2 +- test/TensorFlowNET.UnitTest/TrainSaverTest.cs | 29 +++ 18 files changed, 405 insertions(+), 93 deletions(-) create mode 100644 src/TensorFlowNET.Core/APIs/tf.random.cs create mode 100644 src/TensorFlowNET.Core/Framework/random_seed.py.cs create mode 100644 src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs create mode 100644 src/TensorFlowNET.Core/Operations/random_ops.py.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs new file mode 100644 index 00000000..4ec40d11 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.random.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + /// + /// Outputs random values from a normal distribution. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor random_normal(int[] shape, + float mean = 0.0f, + float stddev = 1.0f, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = "") => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); + } +} diff --git a/src/TensorFlowNET.Core/Framework/random_seed.py.cs b/src/TensorFlowNET.Core/Framework/random_seed.py.cs new file mode 100644 index 00000000..eb2ea386 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/random_seed.py.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class random_seed + { + public static (int?, int?) get_seed(int? op_seed = null) + { + return (null, null); + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index a926a57f..79e9fbe5 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -141,10 +141,9 @@ namespace Tensorflow } if (String.IsNullOrEmpty(name)) - { name = op_type; - } - + // If a names ends with a '/' it is a "name scope" and we use it as-is, + // after removing the trailing '/'. name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 76dd318a..62b4821a 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -14,7 +14,7 @@ namespace Tensorflow { public Operation _apply_op_helper(string op_type_name, string name = "", dynamic args = null) { - var keywords = ConvertToDict(args); + Dictionary keywords = ConvertToDict(args); var g = ops.get_default_graph(); var op_def = g.GetOpDef(op_type_name); @@ -42,7 +42,8 @@ namespace Tensorflow var attrs = new Dictionary(); var inputs = new List(); var input_types = new List(); - + dynamic values = null; + return Python.with(new ops.name_scope(name), scope => { var inferred_from = new Dictionary(); @@ -53,7 +54,17 @@ namespace Tensorflow foreach (var input_arg in op_def.InputArg) { var input_name = input_arg.Name; - var values = keywords[input_name]; + + if (keywords.ContainsKey(input_name)) + values = keywords[input_name]; + else if (keywords.ContainsKey(input_name + "_")) + { + input_name += "_"; + values = keywords[input_name]; + } + else + throw new TypeError("No argument for input " + input_name); + // Goals: // * Convert values to Tensors if it contains constants. // * Verify that values is a list if that matches the input_arg's @@ -92,8 +103,8 @@ namespace Tensorflow values = ops.internal_convert_n_to_tensor(values, name: input_arg.Name, - dtype: dtype, - preferred_dtype: default_dtype, + dtype: dtype.as_tf_dtype(), + preferred_dtype: default_dtype.as_tf_dtype(), as_ref: input_arg.IsRef); } else @@ -107,9 +118,9 @@ namespace Tensorflow values = ops.internal_convert_to_tensor(values, name: input_name, - dtype: dtype, + dtype: dtype.as_tf_dtype(), as_ref: input_arg.IsRef, - preferred_dtype: default_dtype); + preferred_dtype: default_dtype.as_tf_dtype()); //if (!String.IsNullOrEmpty(input_arg.TypeAttr)) //attrs[input_arg.TypeAttr] = values.dtype; diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 28ff42cf..40e9f38f 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -7,6 +7,8 @@ namespace Tensorflow { public class array_ops { + public static Tensor placeholder_with_default(T input, int[] shape, string name = "") => gen_array_ops.placeholder_with_default(input, shape, name); + public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") { dtype = dtype.as_base_dtype(); @@ -35,13 +37,13 @@ namespace Tensorflow var nd = np.zeros(shape); if (shape.Size < 1000) { - return constant_op.constant(nd, name); + return constant_op.constant(nd, name: name); } else { tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape()); var c = constant_op.constant(0); - return gen_array_ops.fill(tShape, c, name); + return gen_array_ops.fill(tShape, c, name: name); } } @@ -99,7 +101,7 @@ namespace Tensorflow if (optimize && input_shape.is_fully_defined()) { var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); - return constant_op.constant(nd, name); + return constant_op.constant(nd, name: name); } } @@ -122,7 +124,7 @@ namespace Tensorflow if (input_shape.is_fully_defined()) { var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); - return constant_op.constant(nd, name); + return constant_op.constant(nd, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index b3a3e607..8bf345fc 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -113,7 +113,7 @@ namespace Tensorflow /// /// /// - public static Tensor placeholder_with_default(T input, TensorShape shape, string name = "") + public static Tensor placeholder_with_default(T input, int[] shape, string name = "") { var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name }); return _op.outputs[0]; diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs new file mode 100644 index 00000000..643e7d0a --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class gen_random_ops + { + public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + + /// + /// Outputs random values from a normal distribution. + /// + /// + /// + /// + /// + /// + /// + public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = "") + { + if (!seed.HasValue) + seed = 0; + if (!seed2.HasValue) + seed2 = 0; + + var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", name: name, + args: new { shape, dtype, seed, seed2 }); + + return _op.outputs[0]; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.py.cs index ef273c49..1e78285a 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.py.cs @@ -6,6 +6,8 @@ namespace Tensorflow { public class math_ops { + public static Tensor add(Tensor x, Tensor y, string name = "") => gen_math_ops.add(x, y, name); + public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") { var base_type = dtype.as_base_dtype(); diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs new file mode 100644 index 00000000..3334a2d9 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class random_ops + { + public static Tensor random_normal(int[] shape, + float mean = 0.0f, + float stddev = 1.0f, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = "") + { + return Python.with(new ops.name_scope(name, "random_normal", new object[] { shape, mean, stddev }), scope => + { + var shape_tensor = _ShapeTensor(shape); + var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); + var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name = "stddev"); + var (seed1, seed2) = random_seed.get_seed(seed); + var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); + var mul = rnd * stddev_tensor; + var value = math_ops.add(mul, mean_tensor, name: name); + return value; + }); + } + + private static Tensor _ShapeTensor(int[] shape) + { + return ops.convert_to_tensor(shape, name: "shape"); + } + } +} + diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 5a828029..8044e904 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -55,11 +55,19 @@ namespace Tensorflow var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); switch(subfeed.Value) { + case float floatVal: + feed_dict_tensor[subfeed_t] = (NDArray)floatVal; + break; + case int intVal: + feed_dict_tensor[subfeed_t] = (NDArray)intVal; + break; case string str: - feed_dict_tensor[subfeed_t] = np.array(str); - feed_map[subfeed_t.name] = new Tuple(subfeed_t, subfeed.Value); + feed_dict_tensor[subfeed_t] = (NDArray)str; break; + default: + throw new NotImplementedException("_run subfeed"); } + feed_map[subfeed_t.name] = new Tuple(subfeed_t, subfeed.Value); } } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index e1f930f2..f65d71c4 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -19,7 +19,12 @@ namespace Tensorflow /// Optional name for the tensor. /// Boolean that enables verification of a shape of values. /// - public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) + public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const") + { + return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true); + } + + private static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast) { if (tf.context.executing_eagerly()) { @@ -27,13 +32,13 @@ namespace Tensorflow } Graph g = ops.get_default_graph(); - var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape); - var tensor_value = new AttrValue - { - Type = tensor_pb.Dtype, - Tensor = tensor_pb - }; - + var tensor_value = new AttrValue(); + tensor_value.Tensor = tensor_util.make_tensor_proto(value, + dtype: dtype, + shape: shape, + verify_shape: verify_shape, + allow_broadcast: allow_broadcast); + var dtype_value = new AttrValue { Type = tensor_value.Tensor.Dtype, @@ -44,8 +49,8 @@ namespace Tensorflow attrs["dtype"] = dtype_value; var op = g.create_op("Const", - null, - new TF_DataType[] { (TF_DataType)dtype_value.Type }, + new Tensor[0], + new TF_DataType[] { dtype_value.Type.as_tf_dtype() }, attrs: attrs, name: name); @@ -81,7 +86,7 @@ namespace Tensorflow if (string.IsNullOrEmpty(name)) name = "shape_as_tensor"; - return constant_op.constant(s_list, name); + return constant_op.constant(s_list, name: name); } } } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 31bfe3e2..4c0bd693 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -83,6 +83,20 @@ namespace Tensorflow type; } + public static TF_DataType as_tf_dtype(this DataType type) + { + TF_DataType dtype = TF_DataType.DtInvalid; + + switch (type) + { + default: + Enum.TryParse(((int)type).ToString(), out dtype); + break; + } + + return dtype; + } + public static TF_DataType as_ref(this TF_DataType type) { return (int)type < 100 ? diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 59aa3639..8bdac8d2 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -10,33 +10,166 @@ namespace Tensorflow { public static class tensor_util { - public static TensorProto make_tensor_proto(NDArray nd, bool verify_shape = false) + public static TF_DataType[] _TENSOR_CONTENT_TYPES = { - var shape = nd.Storage.Shape; + TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE, TF_DataType.TF_INT32, TF_DataType.TF_UINT8, TF_DataType.TF_INT16, + TF_DataType.TF_INT8, TF_DataType.TF_INT64, TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, + TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 + }; + + /// + /// Create a TensorProto. + /// + /// + /// + /// + /// + /// + /// + public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, bool verify_shape = false, bool allow_broadcast = false) + { + if (allow_broadcast && verify_shape) + throw new ValueError("allow_broadcast and verify_shape are not both allowed."); + if (values is TensorProto tp) + return tp; + + if (dtype != TF_DataType.DtInvalid) + ; + + bool is_quantized = new TF_DataType[] + { + TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, + TF_DataType.TF_QINT32 + }.Contains(dtype); + + // We first convert value to a numpy array or scalar. + NDArray nparray = null; + + if (values is NDArray nd) + { + nparray = nd; + } + else + { + if (values == null) + throw new ValueError("None values not supported."); + + switch (values) + { + /*case bool boolVal: + nparray = boolVal; + break;*/ + case int intVal: + nparray = intVal; + break; + case int[] intVals: + nparray = np.array(intVals); + break; + case float floatVal: + nparray = floatVal; + break; + case double doubleVal: + nparray = doubleVal; + break; + case string strVal: + nparray = strVal; + break; + default: + throw new Exception("make_tensor_proto Not Implemented"); + } + } + + var numpy_dtype = dtypes.as_dtype(nparray.dtype); + if (numpy_dtype == TF_DataType.DtInvalid) + throw new TypeError($"Unrecognized data type: {nparray.dtype}"); + + // If dtype was specified and is a quantized type, we convert + // numpy_dtype back into the quantized version. + if (is_quantized) + numpy_dtype = dtype; + + bool is_same_size = false; + int shape_size = 0; + + // If shape is not given, get the shape from the numpy array. + if (shape == null) + { + shape = nparray.shape; + is_same_size = true; + shape_size = nparray.size; + } + else + { + throw new NotImplementedException("make_tensor_proto shape not implemented"); + } - var numpy_dtype = dtypes.as_dtype(nd.dtype); var tensor_proto = new tensor_pb2.TensorProto { Dtype = numpy_dtype.as_datatype_enum(), - TensorShape = shape.reshape(nd.shape).as_proto() + TensorShape = tensor_util.as_shape(shape) }; - switch (nd.dtype.Name) + if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) + { + var bytes = new List(); + var nd2 = nparray.ravel(); + switch (nparray.dtype.Name) + { + case "Int32": + nd2.Data().Select(x => + { + bytes.AddRange(BitConverter.GetBytes(x)); + return x; + }).ToArray(); + break; + case "Single": + nd2.Data().Select(x => + { + bytes.AddRange(BitConverter.GetBytes(x)); + return x; + }).ToArray(); + break; + case "Double": + nd2.Data().Select(x => + { + bytes.AddRange(BitConverter.GetBytes(x)); + return x; + }).ToArray(); + break; + default: + throw new Exception("make_tensor_proto Not Implemented"); + } + tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); + return tensor_proto; + } + + if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray) && values is string str) { + tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); + return tensor_proto; + } + + var proto_values = nparray.ravel(); + + switch (nparray.dtype.Name) + { + case "Bool": + tensor_proto.BoolVal.AddRange(proto_values.Data()); + break; case "Int32": - tensor_proto.IntVal.AddRange(nd.Data()); + tensor_proto.IntVal.AddRange(proto_values.Data()); break; case "Single": - tensor_proto.FloatVal.AddRange(nd.Data()); + tensor_proto.FloatVal.AddRange(proto_values.Data()); break; case "Double": - tensor_proto.DoubleVal.AddRange(nd.Data()); + tensor_proto.DoubleVal.AddRange(proto_values.Data()); break; case "String": - tensor_proto.StringVal.AddRange(nd.Data().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); + tensor_proto.StringVal.AddRange(proto_values.Data().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString()))); break; default: - throw new Exception("Not Implemented"); + throw new Exception("make_tensor_proto Not Implemented"); } return tensor_proto; @@ -73,14 +206,24 @@ namespace Tensorflow return nd; } - public static TensorShapeProto as_shape(long[] dims) + public static TensorShapeProto as_shape(T[] dims) { TensorShapeProto shape = new TensorShapeProto(); for (int i = 0; i < dims.Length; i++) { var dim = new TensorShapeProto.Types.Dim(); - dim.Size = dims[i]; + switch(dims[i]) + { + case int n: + dim.Size = n; + break; + case long l: + dim.Size = l; + break; + default: + throw new NotImplementedException("as_shape Not Implemented"); + } dim.Name = $"dim_{i}"; shape.Dim.Add(dim); diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs index 26711316..93bbc459 100644 --- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -7,14 +7,8 @@ namespace Tensorflow { public static partial class tf { - public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) - { - return constant_op.constant(nd, name, verify_shape); - } + public static Tensor constant(NDArray nd, string name = "Const") => constant_op.constant(nd, name: name); - public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") - { - return array_ops.zeros(shape, dtype, name); - } + public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") => array_ops.zeros(shape, dtype, name); } } diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs index 7ce6987a..0c7875f9 100644 --- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs @@ -84,8 +84,8 @@ namespace Tensorflow name = scope; // Add a placeholder string tensor for the filename. - var filename_tensor = gen_array_ops.placeholder_with_default( string.IsNullOrEmpty(filename) ? "model" : filename, shape: new TensorShape(), name: "filename"); - filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new TensorShape(), name: "Const"); + var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename"); + filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); // Keep the name "Const" for backwards compatibility. // Add the save ops. diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index e43881c1..74575cff 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -68,16 +68,14 @@ namespace Tensorflow /// /// /// - public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") + public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid) { - switch (value) - { - case Tensor val: - return val; - default: - var nd = tensor_util.convert_to_numpy_ndarray(value); - return constant_op.constant(nd, name); - } + return convert_to_tensor_v2(value, dtype, preferred_dtype, name); + } + + public static Tensor convert_to_tensor_v2(object value, TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype_hint = TF_DataType.DtInvalid, string name = "") + { + return internal_convert_to_tensor(value, dtype: dtype, name: name, preferred_dtype: dtype_hint, as_ref: false); } public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") @@ -87,7 +85,7 @@ namespace Tensorflow public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false) { - return internal_convert_to_tensor(value, dtype: dtype.as_datatype_enum(), name: name, as_ref: as_ref); + return internal_convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref); } /// @@ -117,17 +115,14 @@ namespace Tensorflow var op_desc = graph.NewOperation(node_def.Op, node_def.Name); // Add inputs - if(inputs != null) + foreach (var op_input in inputs) { - foreach (var op_input in inputs) - { - if (op_input is Tensor[] op_inputs) - c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); - else if (op_input is Tensor op_input1) - c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); - else - throw new NotImplementedException("_create_c_op"); - } + if (op_input is Tensor[] op_inputs) + c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); + else if (op_input is Tensor op_input1) + c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); + else + throw new NotImplementedException("_create_c_op"); } var status = new Status(); @@ -142,8 +137,8 @@ namespace Tensorflow var bytes = attr.Value.ToByteArray(); var proto = Marshal.AllocHGlobal(bytes.Length); Marshal.Copy(bytes, 0, proto, bytes.Length); - - c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status); + uint len = (uint)bytes.Length; + c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); status.Check(true); } @@ -385,8 +380,8 @@ namespace Tensorflow return ret.ToArray(); } - public static Tensor[] internal_convert_n_to_tensor(T[] values, DataType dtype = DataType.DtInvalid, - string name = "", DataType preferred_dtype = DataType.DtInvalid, + public static Tensor[] internal_convert_n_to_tensor(T[] values, TF_DataType dtype = TF_DataType.DtInvalid, + string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid, bool as_ref = false) { var ret = new List(); @@ -400,28 +395,30 @@ namespace Tensorflow return ret.ToArray(); } - public static Tensor internal_convert_to_tensor(T value, DataType dtype = DataType.DtInvalid, - string name = "", DataType preferred_dtype = DataType.DtInvalid, + public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, + string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid, bool as_ref = false) { - switch (typeof(T).Name) + switch (value) { - case "Tensor": - return value as Tensor; - case "String": - return constant_op.constant(Convert.ToString(value), name); - case "String[]": - return constant_op.constant(value as string[], name); - case "Int32": - return constant_op.constant(Convert.ToInt32(value), name); - case "Single": - return constant_op.constant(Convert.ToSingle(value), name); - case "Double": - return constant_op.constant(Convert.ToDouble(value), name); - case "RefVariable": - return (value as RefVariable)._TensorConversionFunction(as_ref: as_ref); + case Tensor tensor: + return tensor; + case string str: + return constant_op.constant(str, dtype: dtype, name: name); + case string[] strArray: + return constant_op.constant(strArray, dtype: dtype, name: name); + case int intVal: + return constant_op.constant(intVal, dtype: dtype, name: name); + case int[] intArray: + return constant_op.constant(intArray, dtype: dtype, name: name); + case float floatVal: + return constant_op.constant(floatVal, dtype: dtype, name: name); + case double doubleVal: + return constant_op.constant(doubleVal, dtype: dtype, name: name); + case RefVariable varVal: + return varVal._TensorConversionFunction(as_ref: as_ref); default: - throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {typeof(T).Name} to Tensor"); + throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor"); } } } diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs index 588e10d1..cbd1578d 100644 --- a/test/TensorFlowNET.UnitTest/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -79,7 +79,7 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual(result.shape[0], 2); Assert.AreEqual(result.shape[1], 3); - Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 2, 1, 1, 3 }, data)); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 2, 1, 1, 1, 3 }, data)); }); } diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index b7c33e5b..53fab4eb 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -17,6 +17,35 @@ namespace TensorFlowNET.UnitTest tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); } + [TestMethod] + public void ImportGraph() + { + var v = tf.Variable(0, name: "my_variable"); + var sess = tf.Session(); + tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); + } + + [TestMethod] + public void SaveSimple() + { + var w1 = tf.Variable(tf.random_normal(new int[] { 2 }), name: "w1"); + var w2 = tf.Variable(tf.random_normal(new int[] { 5 }), name: "w2"); + + var init_op = tf.global_variables_initializer(); + + // Add ops to save and restore all the variables. + var saver = tf.train.Saver(); + + with(tf.Session(), sess => + { + sess.run(init_op); + + // Save the variables to disk. + var save_path = saver.save(sess, "/tmp/model.ckpt"); + Console.WriteLine($"Model saved in path: {save_path}"); + }); + } + [TestMethod] public void Save() {