| @@ -0,0 +1,26 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public static partial class tf | |||
| { | |||
| /// <summary> | |||
| /// Outputs random values from a normal distribution. | |||
| /// </summary> | |||
| /// <param name="shape"></param> | |||
| /// <param name="mean"></param> | |||
| /// <param name="stddev"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="seed"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor random_normal(int[] shape, | |||
| float mean = 0.0f, | |||
| float stddev = 1.0f, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| int? seed = null, | |||
| string name = "") => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | |||
| } | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class random_seed | |||
| { | |||
| public static (int?, int?) get_seed(int? op_seed = null) | |||
| { | |||
| return (null, null); | |||
| } | |||
| } | |||
| } | |||
| @@ -141,10 +141,9 @@ namespace Tensorflow | |||
| } | |||
| if (String.IsNullOrEmpty(name)) | |||
| { | |||
| name = op_type; | |||
| } | |||
| // If a names ends with a '/' it is a "name scope" and we use it as-is, | |||
| // after removing the trailing '/'. | |||
| name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); | |||
| var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | |||
| @@ -14,7 +14,7 @@ namespace Tensorflow | |||
| { | |||
| public Operation _apply_op_helper(string op_type_name, string name = "", dynamic args = null) | |||
| { | |||
| var keywords = ConvertToDict(args); | |||
| Dictionary<string, object> keywords = ConvertToDict(args); | |||
| var g = ops.get_default_graph(); | |||
| var op_def = g.GetOpDef(op_type_name); | |||
| @@ -42,7 +42,8 @@ namespace Tensorflow | |||
| var attrs = new Dictionary<string, object>(); | |||
| var inputs = new List<Tensor>(); | |||
| var input_types = new List<TF_DataType>(); | |||
| dynamic values = null; | |||
| return Python.with<ops.name_scope, Operation>(new ops.name_scope(name), scope => | |||
| { | |||
| var inferred_from = new Dictionary<string, object>(); | |||
| @@ -53,7 +54,17 @@ namespace Tensorflow | |||
| foreach (var input_arg in op_def.InputArg) | |||
| { | |||
| var input_name = input_arg.Name; | |||
| var values = keywords[input_name]; | |||
| if (keywords.ContainsKey(input_name)) | |||
| values = keywords[input_name]; | |||
| else if (keywords.ContainsKey(input_name + "_")) | |||
| { | |||
| input_name += "_"; | |||
| values = keywords[input_name]; | |||
| } | |||
| else | |||
| throw new TypeError("No argument for input " + input_name); | |||
| // Goals: | |||
| // * Convert values to Tensors if it contains constants. | |||
| // * Verify that values is a list if that matches the input_arg's | |||
| @@ -92,8 +103,8 @@ namespace Tensorflow | |||
| values = ops.internal_convert_n_to_tensor(values, | |||
| name: input_arg.Name, | |||
| dtype: dtype, | |||
| preferred_dtype: default_dtype, | |||
| dtype: dtype.as_tf_dtype(), | |||
| preferred_dtype: default_dtype.as_tf_dtype(), | |||
| as_ref: input_arg.IsRef); | |||
| } | |||
| else | |||
| @@ -107,9 +118,9 @@ namespace Tensorflow | |||
| values = ops.internal_convert_to_tensor(values, | |||
| name: input_name, | |||
| dtype: dtype, | |||
| dtype: dtype.as_tf_dtype(), | |||
| as_ref: input_arg.IsRef, | |||
| preferred_dtype: default_dtype); | |||
| preferred_dtype: default_dtype.as_tf_dtype()); | |||
| //if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||
| //attrs[input_arg.TypeAttr] = values.dtype; | |||
| @@ -7,6 +7,8 @@ namespace Tensorflow | |||
| { | |||
| public class array_ops | |||
| { | |||
| public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = "") => gen_array_ops.placeholder_with_default(input, shape, name); | |||
| public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") | |||
| { | |||
| dtype = dtype.as_base_dtype(); | |||
| @@ -35,13 +37,13 @@ namespace Tensorflow | |||
| var nd = np.zeros<T>(shape); | |||
| if (shape.Size < 1000) | |||
| { | |||
| return constant_op.constant(nd, name); | |||
| return constant_op.constant(nd, name: name); | |||
| } | |||
| else | |||
| { | |||
| tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape()); | |||
| var c = constant_op.constant(0); | |||
| return gen_array_ops.fill(tShape, c, name); | |||
| return gen_array_ops.fill(tShape, c, name: name); | |||
| } | |||
| } | |||
| @@ -99,7 +101,7 @@ namespace Tensorflow | |||
| if (optimize && input_shape.is_fully_defined()) | |||
| { | |||
| var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); | |||
| return constant_op.constant(nd, name); | |||
| return constant_op.constant(nd, name: name); | |||
| } | |||
| } | |||
| @@ -122,7 +124,7 @@ namespace Tensorflow | |||
| if (input_shape.is_fully_defined()) | |||
| { | |||
| var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype()); | |||
| return constant_op.constant(nd, name); | |||
| return constant_op.constant(nd, name: name); | |||
| } | |||
| } | |||
| @@ -113,7 +113,7 @@ namespace Tensorflow | |||
| /// <param name="shape"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor placeholder_with_default<T>(T input, TensorShape shape, string name = "") | |||
| public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = "") | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name }); | |||
| return _op.outputs[0]; | |||
| @@ -0,0 +1,33 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class gen_random_ops | |||
| { | |||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
| /// <summary> | |||
| /// Outputs random values from a normal distribution. | |||
| /// </summary> | |||
| /// <param name="shape"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="seed"></param> | |||
| /// <param name="seed2"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = "") | |||
| { | |||
| if (!seed.HasValue) | |||
| seed = 0; | |||
| if (!seed2.HasValue) | |||
| seed2 = 0; | |||
| var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", name: name, | |||
| args: new { shape, dtype, seed, seed2 }); | |||
| return _op.outputs[0]; | |||
| } | |||
| } | |||
| } | |||
| @@ -6,6 +6,8 @@ namespace Tensorflow | |||
| { | |||
| public class math_ops | |||
| { | |||
| public static Tensor add(Tensor x, Tensor y, string name = "") => gen_math_ops.add(x, y, name); | |||
| public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | |||
| { | |||
| var base_type = dtype.as_base_dtype(); | |||
| @@ -0,0 +1,35 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class random_ops | |||
| { | |||
| public static Tensor random_normal(int[] shape, | |||
| float mean = 0.0f, | |||
| float stddev = 1.0f, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| int? seed = null, | |||
| string name = "") | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "random_normal", new object[] { shape, mean, stddev }), scope => | |||
| { | |||
| var shape_tensor = _ShapeTensor(shape); | |||
| var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); | |||
| var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name = "stddev"); | |||
| var (seed1, seed2) = random_seed.get_seed(seed); | |||
| var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2); | |||
| var mul = rnd * stddev_tensor; | |||
| var value = math_ops.add(mul, mean_tensor, name: name); | |||
| return value; | |||
| }); | |||
| } | |||
| private static Tensor _ShapeTensor(int[] shape) | |||
| { | |||
| return ops.convert_to_tensor(shape, name: "shape"); | |||
| } | |||
| } | |||
| } | |||
| @@ -55,11 +55,19 @@ namespace Tensorflow | |||
| var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); | |||
| switch(subfeed.Value) | |||
| { | |||
| case float floatVal: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)floatVal; | |||
| break; | |||
| case int intVal: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)intVal; | |||
| break; | |||
| case string str: | |||
| feed_dict_tensor[subfeed_t] = np.array(str); | |||
| feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value); | |||
| feed_dict_tensor[subfeed_t] = (NDArray)str; | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("_run subfeed"); | |||
| } | |||
| feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value); | |||
| } | |||
| } | |||
| @@ -19,7 +19,12 @@ namespace Tensorflow | |||
| /// <param name="name">Optional name for the tensor.</param> | |||
| /// <param name="verify_shape">Boolean that enables verification of a shape of values.</param> | |||
| /// <returns></returns> | |||
| public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) | |||
| public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const") | |||
| { | |||
| return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true); | |||
| } | |||
| private static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast) | |||
| { | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| @@ -27,13 +32,13 @@ namespace Tensorflow | |||
| } | |||
| Graph g = ops.get_default_graph(); | |||
| var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape); | |||
| var tensor_value = new AttrValue | |||
| { | |||
| Type = tensor_pb.Dtype, | |||
| Tensor = tensor_pb | |||
| }; | |||
| var tensor_value = new AttrValue(); | |||
| tensor_value.Tensor = tensor_util.make_tensor_proto(value, | |||
| dtype: dtype, | |||
| shape: shape, | |||
| verify_shape: verify_shape, | |||
| allow_broadcast: allow_broadcast); | |||
| var dtype_value = new AttrValue | |||
| { | |||
| Type = tensor_value.Tensor.Dtype, | |||
| @@ -44,8 +49,8 @@ namespace Tensorflow | |||
| attrs["dtype"] = dtype_value; | |||
| var op = g.create_op("Const", | |||
| null, | |||
| new TF_DataType[] { (TF_DataType)dtype_value.Type }, | |||
| new Tensor[0], | |||
| new TF_DataType[] { dtype_value.Type.as_tf_dtype() }, | |||
| attrs: attrs, | |||
| name: name); | |||
| @@ -81,7 +86,7 @@ namespace Tensorflow | |||
| if (string.IsNullOrEmpty(name)) | |||
| name = "shape_as_tensor"; | |||
| return constant_op.constant(s_list, name); | |||
| return constant_op.constant(s_list, name: name); | |||
| } | |||
| } | |||
| } | |||
| @@ -83,6 +83,20 @@ namespace Tensorflow | |||
| type; | |||
| } | |||
| public static TF_DataType as_tf_dtype(this DataType type) | |||
| { | |||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||
| switch (type) | |||
| { | |||
| default: | |||
| Enum.TryParse(((int)type).ToString(), out dtype); | |||
| break; | |||
| } | |||
| return dtype; | |||
| } | |||
| public static TF_DataType as_ref(this TF_DataType type) | |||
| { | |||
| return (int)type < 100 ? | |||
| @@ -10,33 +10,166 @@ namespace Tensorflow | |||
| { | |||
| public static class tensor_util | |||
| { | |||
| public static TensorProto make_tensor_proto(NDArray nd, bool verify_shape = false) | |||
| public static TF_DataType[] _TENSOR_CONTENT_TYPES = | |||
| { | |||
| var shape = nd.Storage.Shape; | |||
| TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE, TF_DataType.TF_INT32, TF_DataType.TF_UINT8, TF_DataType.TF_INT16, | |||
| TF_DataType.TF_INT8, TF_DataType.TF_INT64, TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, | |||
| TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 | |||
| }; | |||
| /// <summary> | |||
| /// Create a TensorProto. | |||
| /// </summary> | |||
| /// <param name="values"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="shape"></param> | |||
| /// <param name="verify_shape"></param> | |||
| /// <param name="allow_broadcast"></param> | |||
| /// <returns></returns> | |||
| public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, bool verify_shape = false, bool allow_broadcast = false) | |||
| { | |||
| if (allow_broadcast && verify_shape) | |||
| throw new ValueError("allow_broadcast and verify_shape are not both allowed."); | |||
| if (values is TensorProto tp) | |||
| return tp; | |||
| if (dtype != TF_DataType.DtInvalid) | |||
| ; | |||
| bool is_quantized = new TF_DataType[] | |||
| { | |||
| TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, | |||
| TF_DataType.TF_QINT32 | |||
| }.Contains(dtype); | |||
| // We first convert value to a numpy array or scalar. | |||
| NDArray nparray = null; | |||
| if (values is NDArray nd) | |||
| { | |||
| nparray = nd; | |||
| } | |||
| else | |||
| { | |||
| if (values == null) | |||
| throw new ValueError("None values not supported."); | |||
| switch (values) | |||
| { | |||
| /*case bool boolVal: | |||
| nparray = boolVal; | |||
| break;*/ | |||
| case int intVal: | |||
| nparray = intVal; | |||
| break; | |||
| case int[] intVals: | |||
| nparray = np.array(intVals); | |||
| break; | |||
| case float floatVal: | |||
| nparray = floatVal; | |||
| break; | |||
| case double doubleVal: | |||
| nparray = doubleVal; | |||
| break; | |||
| case string strVal: | |||
| nparray = strVal; | |||
| break; | |||
| default: | |||
| throw new Exception("make_tensor_proto Not Implemented"); | |||
| } | |||
| } | |||
| var numpy_dtype = dtypes.as_dtype(nparray.dtype); | |||
| if (numpy_dtype == TF_DataType.DtInvalid) | |||
| throw new TypeError($"Unrecognized data type: {nparray.dtype}"); | |||
| // If dtype was specified and is a quantized type, we convert | |||
| // numpy_dtype back into the quantized version. | |||
| if (is_quantized) | |||
| numpy_dtype = dtype; | |||
| bool is_same_size = false; | |||
| int shape_size = 0; | |||
| // If shape is not given, get the shape from the numpy array. | |||
| if (shape == null) | |||
| { | |||
| shape = nparray.shape; | |||
| is_same_size = true; | |||
| shape_size = nparray.size; | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("make_tensor_proto shape not implemented"); | |||
| } | |||
| var numpy_dtype = dtypes.as_dtype(nd.dtype); | |||
| var tensor_proto = new tensor_pb2.TensorProto | |||
| { | |||
| Dtype = numpy_dtype.as_datatype_enum(), | |||
| TensorShape = shape.reshape(nd.shape).as_proto() | |||
| TensorShape = tensor_util.as_shape(shape) | |||
| }; | |||
| switch (nd.dtype.Name) | |||
| if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) | |||
| { | |||
| var bytes = new List<byte>(); | |||
| var nd2 = nparray.ravel(); | |||
| switch (nparray.dtype.Name) | |||
| { | |||
| case "Int32": | |||
| nd2.Data<int>().Select(x => | |||
| { | |||
| bytes.AddRange(BitConverter.GetBytes(x)); | |||
| return x; | |||
| }).ToArray(); | |||
| break; | |||
| case "Single": | |||
| nd2.Data<float>().Select(x => | |||
| { | |||
| bytes.AddRange(BitConverter.GetBytes(x)); | |||
| return x; | |||
| }).ToArray(); | |||
| break; | |||
| case "Double": | |||
| nd2.Data<double>().Select(x => | |||
| { | |||
| bytes.AddRange(BitConverter.GetBytes(x)); | |||
| return x; | |||
| }).ToArray(); | |||
| break; | |||
| default: | |||
| throw new Exception("make_tensor_proto Not Implemented"); | |||
| } | |||
| tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); | |||
| return tensor_proto; | |||
| } | |||
| if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray) && values is string str) | |||
| { | |||
| tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); | |||
| return tensor_proto; | |||
| } | |||
| var proto_values = nparray.ravel(); | |||
| switch (nparray.dtype.Name) | |||
| { | |||
| case "Bool": | |||
| tensor_proto.BoolVal.AddRange(proto_values.Data<bool>()); | |||
| break; | |||
| case "Int32": | |||
| tensor_proto.IntVal.AddRange(nd.Data<int>()); | |||
| tensor_proto.IntVal.AddRange(proto_values.Data<int>()); | |||
| break; | |||
| case "Single": | |||
| tensor_proto.FloatVal.AddRange(nd.Data<float>()); | |||
| tensor_proto.FloatVal.AddRange(proto_values.Data<float>()); | |||
| break; | |||
| case "Double": | |||
| tensor_proto.DoubleVal.AddRange(nd.Data<double>()); | |||
| tensor_proto.DoubleVal.AddRange(proto_values.Data<double>()); | |||
| break; | |||
| case "String": | |||
| tensor_proto.StringVal.AddRange(nd.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); | |||
| tensor_proto.StringVal.AddRange(proto_values.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString()))); | |||
| break; | |||
| default: | |||
| throw new Exception("Not Implemented"); | |||
| throw new Exception("make_tensor_proto Not Implemented"); | |||
| } | |||
| return tensor_proto; | |||
| @@ -73,14 +206,24 @@ namespace Tensorflow | |||
| return nd; | |||
| } | |||
| public static TensorShapeProto as_shape(long[] dims) | |||
| public static TensorShapeProto as_shape<T>(T[] dims) | |||
| { | |||
| TensorShapeProto shape = new TensorShapeProto(); | |||
| for (int i = 0; i < dims.Length; i++) | |||
| { | |||
| var dim = new TensorShapeProto.Types.Dim(); | |||
| dim.Size = dims[i]; | |||
| switch(dims[i]) | |||
| { | |||
| case int n: | |||
| dim.Size = n; | |||
| break; | |||
| case long l: | |||
| dim.Size = l; | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("as_shape Not Implemented"); | |||
| } | |||
| dim.Name = $"dim_{i}"; | |||
| shape.Dim.Add(dim); | |||
| @@ -7,14 +7,8 @@ namespace Tensorflow | |||
| { | |||
| public static partial class tf | |||
| { | |||
| public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) | |||
| { | |||
| return constant_op.constant(nd, name, verify_shape); | |||
| } | |||
| public static Tensor constant(NDArray nd, string name = "Const") => constant_op.constant(nd, name: name); | |||
| public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") | |||
| { | |||
| return array_ops.zeros(shape, dtype, name); | |||
| } | |||
| public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") => array_ops.zeros(shape, dtype, name); | |||
| } | |||
| } | |||
| @@ -84,8 +84,8 @@ namespace Tensorflow | |||
| name = scope; | |||
| // Add a placeholder string tensor for the filename. | |||
| var filename_tensor = gen_array_ops.placeholder_with_default( string.IsNullOrEmpty(filename) ? "model" : filename, shape: new TensorShape(), name: "filename"); | |||
| filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new TensorShape(), name: "Const"); | |||
| var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename"); | |||
| filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); | |||
| // Keep the name "Const" for backwards compatibility. | |||
| // Add the save ops. | |||
| @@ -68,16 +68,14 @@ namespace Tensorflow | |||
| /// <param name="dtype"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | |||
| public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid) | |||
| { | |||
| switch (value) | |||
| { | |||
| case Tensor val: | |||
| return val; | |||
| default: | |||
| var nd = tensor_util.convert_to_numpy_ndarray(value); | |||
| return constant_op.constant(nd, name); | |||
| } | |||
| return convert_to_tensor_v2(value, dtype, preferred_dtype, name); | |||
| } | |||
| public static Tensor convert_to_tensor_v2(object value, TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype_hint = TF_DataType.DtInvalid, string name = "") | |||
| { | |||
| return internal_convert_to_tensor(value, dtype: dtype, name: name, preferred_dtype: dtype_hint, as_ref: false); | |||
| } | |||
| public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "") | |||
| @@ -87,7 +85,7 @@ namespace Tensorflow | |||
| public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false) | |||
| { | |||
| return internal_convert_to_tensor<Tensor>(value, dtype: dtype.as_datatype_enum(), name: name, as_ref: as_ref); | |||
| return internal_convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref); | |||
| } | |||
| /// <summary> | |||
| @@ -117,17 +115,14 @@ namespace Tensorflow | |||
| var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||
| // Add inputs | |||
| if(inputs != null) | |||
| foreach (var op_input in inputs) | |||
| { | |||
| foreach (var op_input in inputs) | |||
| { | |||
| if (op_input is Tensor[] op_inputs) | |||
| c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||
| else if (op_input is Tensor op_input1) | |||
| c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||
| else | |||
| throw new NotImplementedException("_create_c_op"); | |||
| } | |||
| if (op_input is Tensor[] op_inputs) | |||
| c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||
| else if (op_input is Tensor op_input1) | |||
| c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||
| else | |||
| throw new NotImplementedException("_create_c_op"); | |||
| } | |||
| var status = new Status(); | |||
| @@ -142,8 +137,8 @@ namespace Tensorflow | |||
| var bytes = attr.Value.ToByteArray(); | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status); | |||
| uint len = (uint)bytes.Length; | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||
| status.Check(true); | |||
| } | |||
| @@ -385,8 +380,8 @@ namespace Tensorflow | |||
| return ret.ToArray(); | |||
| } | |||
| public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid, | |||
| string name = "", DataType preferred_dtype = DataType.DtInvalid, | |||
| public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, TF_DataType dtype = TF_DataType.DtInvalid, | |||
| string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid, | |||
| bool as_ref = false) | |||
| { | |||
| var ret = new List<Tensor>(); | |||
| @@ -400,28 +395,30 @@ namespace Tensorflow | |||
| return ret.ToArray(); | |||
| } | |||
| public static Tensor internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid, | |||
| string name = "", DataType preferred_dtype = DataType.DtInvalid, | |||
| public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, | |||
| string name = "", TF_DataType preferred_dtype = TF_DataType.DtInvalid, | |||
| bool as_ref = false) | |||
| { | |||
| switch (typeof(T).Name) | |||
| switch (value) | |||
| { | |||
| case "Tensor": | |||
| return value as Tensor; | |||
| case "String": | |||
| return constant_op.constant(Convert.ToString(value), name); | |||
| case "String[]": | |||
| return constant_op.constant(value as string[], name); | |||
| case "Int32": | |||
| return constant_op.constant(Convert.ToInt32(value), name); | |||
| case "Single": | |||
| return constant_op.constant(Convert.ToSingle(value), name); | |||
| case "Double": | |||
| return constant_op.constant(Convert.ToDouble(value), name); | |||
| case "RefVariable": | |||
| return (value as RefVariable)._TensorConversionFunction(as_ref: as_ref); | |||
| case Tensor tensor: | |||
| return tensor; | |||
| case string str: | |||
| return constant_op.constant(str, dtype: dtype, name: name); | |||
| case string[] strArray: | |||
| return constant_op.constant(strArray, dtype: dtype, name: name); | |||
| case int intVal: | |||
| return constant_op.constant(intVal, dtype: dtype, name: name); | |||
| case int[] intArray: | |||
| return constant_op.constant(intArray, dtype: dtype, name: name); | |||
| case float floatVal: | |||
| return constant_op.constant(floatVal, dtype: dtype, name: name); | |||
| case double doubleVal: | |||
| return constant_op.constant(doubleVal, dtype: dtype, name: name); | |||
| case RefVariable varVal: | |||
| return varVal._TensorConversionFunction(as_ref: as_ref); | |||
| default: | |||
| throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {typeof(T).Name} to Tensor"); | |||
| throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor"); | |||
| } | |||
| } | |||
| } | |||
| @@ -79,7 +79,7 @@ namespace TensorFlowNET.UnitTest | |||
| Assert.AreEqual(result.shape[0], 2); | |||
| Assert.AreEqual(result.shape[1], 3); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 2, 1, 1, 3 }, data)); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 2, 1, 1, 1, 3 }, data)); | |||
| }); | |||
| } | |||
| @@ -17,6 +17,35 @@ namespace TensorFlowNET.UnitTest | |||
| tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||
| } | |||
| [TestMethod] | |||
| public void ImportGraph() | |||
| { | |||
| var v = tf.Variable(0, name: "my_variable"); | |||
| var sess = tf.Session(); | |||
| tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||
| } | |||
| [TestMethod] | |||
| public void SaveSimple() | |||
| { | |||
| var w1 = tf.Variable(tf.random_normal(new int[] { 2 }), name: "w1"); | |||
| var w2 = tf.Variable(tf.random_normal(new int[] { 5 }), name: "w2"); | |||
| var init_op = tf.global_variables_initializer(); | |||
| // Add ops to save and restore all the variables. | |||
| var saver = tf.train.Saver(); | |||
| with<Session>(tf.Session(), sess => | |||
| { | |||
| sess.run(init_op); | |||
| // Save the variables to disk. | |||
| var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||
| Console.WriteLine($"Model saved in path: {save_path}"); | |||
| }); | |||
| } | |||
| [TestMethod] | |||
| public void Save() | |||
| { | |||