| @@ -18,7 +18,7 @@ namespace Tensorflow | |||
| public static Tensor divide<T>(Tensor x, T[] y, string name = "") where T : struct | |||
| => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); | |||
| public static Tensor pow(Tensor x, double y) => gen_math_ops.pow(x, y); | |||
| public static Tensor pow<T1, T2>(T1 x, T2 y) => gen_math_ops.pow(x, y); | |||
| /// <summary> | |||
| /// Computes the sum of elements across dimensions of a tensor. | |||
| @@ -357,7 +357,7 @@ namespace Tensorflow | |||
| if (y.dtype.is_complex()) | |||
| throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); | |||
| var shape = array_ops.shape(y); | |||
| var constant = constant_op.constant(1.0, name: $"grad_ys_{i}"); | |||
| var constant = constant_op.constant(1.0f, name: $"grad_ys_{i}"); | |||
| var fill = gen_array_ops.fill(shape, constant); | |||
| new_grad_ys.Add(fill); | |||
| } | |||
| @@ -41,6 +41,7 @@ namespace Tensorflow | |||
| public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) | |||
| { | |||
| var ret = new Operation(c_op); | |||
| _add_op(ret); | |||
| var name_key = ret.name.ToLower(); | |||
| if (!_names_in_use.ContainsKey(name_key)) | |||
| @@ -212,6 +212,7 @@ namespace Tensorflow | |||
| public void _add_op(Operation op) | |||
| { | |||
| op._id_value = _next_id(); | |||
| _nodes_by_id[op._id] = op; | |||
| _nodes_by_name[op.name] = op; | |||
| _version = Math.Max(_version, op._id); | |||
| @@ -13,7 +13,7 @@ namespace Tensorflow | |||
| public Graph graph { get; } | |||
| public int _id => _id_value; | |||
| private int _id_value; | |||
| public int _id_value; | |||
| public string type => OpType; | |||
| public Operation op => this; | |||
| @@ -46,8 +46,6 @@ namespace Tensorflow | |||
| _outputs = new Tensor[NumOutputs]; | |||
| for (int i = 0; i < NumOutputs; i++) | |||
| _outputs[i] = new Tensor(this, i, OutputType(i)); | |||
| graph._add_op(this); | |||
| } | |||
| public Operation(Graph g, string opType, string oper_name) | |||
| @@ -100,8 +98,6 @@ namespace Tensorflow | |||
| } | |||
| // This will be set by self.inputs. | |||
| _id_value = graph._next_id(); | |||
| if(op_def == null) | |||
| op_def = g.GetOpDef(node_def.Op); | |||
| @@ -88,8 +88,8 @@ namespace Tensorflow | |||
| public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2) | |||
| { | |||
| int index = 0; | |||
| yield return(t1.Data<T>(index), t2.Data<T>(index)); | |||
| for (int i = 0; i < t1.size; i++) | |||
| yield return (t1.Data<T>(i), t2.Data<T>(i)); | |||
| } | |||
| public static IEnumerable<(T1, T2)> zip<T1, T2>(IList<T1> t1, IList<T2> t2) | |||
| @@ -59,6 +59,7 @@ namespace Tensorflow | |||
| { | |||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | |||
| var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); | |||
| switch (subfeed_val) | |||
| { | |||
| case IntPtr pointer: | |||
| @@ -86,6 +87,7 @@ namespace Tensorflow | |||
| Console.WriteLine($"can't handle data type of subfeed_val"); | |||
| throw new NotImplementedException("_run subfeed"); | |||
| } | |||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | |||
| } | |||
| } | |||
| @@ -45,7 +45,7 @@ Upgraded to TensorFlow 1.13 RC2. | |||
| <ItemGroup> | |||
| <PackageReference Include="Google.Protobuf" Version="3.6.1" /> | |||
| <PackageReference Include="NumSharp" Version="0.7.2" /> | |||
| <PackageReference Include="NumSharp" Version="0.7.3" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -52,20 +52,45 @@ namespace Tensorflow | |||
| var dataType = ToTFDataType(nd.dtype); | |||
| // shape | |||
| var dims = nd.shape.Select(x => (long)x).ToArray(); | |||
| var nd1 = nd.ravel(); | |||
| switch (nd.dtype.Name) | |||
| { | |||
| case "Int16": | |||
| Marshal.Copy(nd.ravel().Data<short>(), 0, dotHandle, nd.size); | |||
| Marshal.Copy(nd1.Data<short>(), 0, dotHandle, nd.size); | |||
| break; | |||
| case "Int32": | |||
| Marshal.Copy(nd.ravel().Data<int>(), 0, dotHandle, nd.size); | |||
| Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size); | |||
| break; | |||
| case "Single": | |||
| Marshal.Copy(nd.ravel().Data<float>(), 0, dotHandle, nd.size); | |||
| Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | |||
| /*if (nd.size > 1) | |||
| { | |||
| var bb = nd.Data<byte>(); | |||
| var bytes = Marshal.AllocHGlobal(bb.Length); | |||
| Marshal.Copy(bb, 0, bytes, bb.Length); | |||
| ulong bytes_len = c_api.TF_StringEncodedSize((ulong)bb.Length); | |||
| var dataTypeByte = ToTFDataType(nd.dtype); | |||
| // shape | |||
| var dims2 = nd.shape.Select(x => (long)x).ToArray(); | |||
| var tfHandle2 = c_api.TF_AllocateTensor(dataTypeByte, | |||
| dims2, | |||
| nd.ndim, | |||
| bytes_len + sizeof(Int64)); | |||
| dotHandle = c_api.TF_TensorData(tfHandle2); | |||
| Marshal.WriteInt64(dotHandle, 0); | |||
| c_api.TF_StringEncode(bytes, (ulong)bb.Length, dotHandle + sizeof(Int64), bytes_len, status); | |||
| return tfHandle2; | |||
| } | |||
| else | |||
| { | |||
| Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | |||
| }*/ | |||
| break; | |||
| case "Double": | |||
| Marshal.Copy(nd.ravel().Data<double>(), 0, dotHandle, nd.size); | |||
| Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size); | |||
| break; | |||
| //case "Byte": | |||
| /*var bb = nd.Data<byte>(); | |||
| @@ -119,7 +144,7 @@ namespace Tensorflow | |||
| dims, | |||
| dims.Length, | |||
| dotHandle, | |||
| size, | |||
| (UIntPtr)size, | |||
| deallocator, | |||
| ref deallocator_called); | |||
| @@ -6,83 +6,70 @@ namespace Tensorflow | |||
| { | |||
| public partial class Tensor | |||
| { | |||
| public static Tensor operator +(Tensor x, Tensor y) | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "add", new Tensor[] { x, y }), scope => | |||
| { | |||
| return gen_math_ops.add(x, y, scope); | |||
| }); | |||
| } | |||
| public static Tensor operator +(Tensor x, int y) | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "add", new object[] { x, y }), scope => | |||
| { | |||
| var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y"); | |||
| return gen_math_ops.add(x, y1, scope); | |||
| }); | |||
| } | |||
| public static Tensor operator +(Tensor x, Tensor y) => BinaryOpWrapper("add", x, y); | |||
| public static Tensor operator +(Tensor x, int y) => BinaryOpWrapper("add", x, y); | |||
| public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1); | |||
| public static Tensor operator -(Tensor t1, Tensor t2) => gen_math_ops.sub(t1, t2); | |||
| public static Tensor operator -(Tensor t1, int t2) => gen_math_ops.sub(t1, t2); | |||
| public static Tensor operator -(Tensor t1, double t2) => gen_math_ops.sub(t1, t2); | |||
| public static Tensor operator *(double x, Tensor y) | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new { x, y }), | |||
| scope => | |||
| { | |||
| var x1 = ops.convert_to_tensor(x, y.dtype.as_base_dtype(), name: "x"); | |||
| return gen_math_ops.mul(x1, y, name: scope); | |||
| }); | |||
| } | |||
| public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y); | |||
| public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y); | |||
| public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y); | |||
| public static Tensor operator *(Tensor x, Tensor y) | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new Tensor[] { x, y }), scope => | |||
| { | |||
| return gen_math_ops.mul(x, y, name: scope); | |||
| }); | |||
| } | |||
| public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y); | |||
| public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y); | |||
| public static Tensor operator *(Tensor x, Tensor y) => BinaryOpWrapper("mul", x, y); | |||
| public static Tensor operator *(Tensor x, int y) => BinaryOpWrapper("mul", x, y); | |||
| public static Tensor operator *(Tensor x, int y) | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new object[] { x, y }), scope => | |||
| { | |||
| var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y"); | |||
| return gen_math_ops.mul(x, y1, name: scope); | |||
| }); | |||
| } | |||
| public static Tensor operator /(Tensor x, Tensor y) => BinaryOpWrapper("truediv", x, y); | |||
| public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y); | |||
| public static Tensor operator /(Tensor x, double y) => BinaryOpWrapper("truediv", x, y); | |||
| public static Tensor operator /(Tensor x, Tensor y) | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("truediv/", "truediv", new Tensor[] { x, y }), scope => | |||
| { | |||
| return gen_math_ops.real_div(x, y, scope); | |||
| }); | |||
| } | |||
| public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y); | |||
| public static Tensor operator /(Tensor x, double y) | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("truediv/", "truediv", new object[] { x, y }), scope => | |||
| { | |||
| var y1 = ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); | |||
| return gen_math_ops.real_div(x, y1, scope); | |||
| }); | |||
| } | |||
| public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y); | |||
| public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y); | |||
| public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y); | |||
| public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y); | |||
| public static Tensor operator %(Tensor x, Tensor y) | |||
| private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mod", new object[] { x, y }), scope => | |||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||
| if (x is Tensor tl) | |||
| dtype = tl.dtype.as_base_dtype(); | |||
| if( y is Tensor tr) | |||
| dtype = tr.dtype.as_base_dtype(); | |||
| var namescope = new ops.name_scope("", name, new { x, y }); | |||
| return Python.with<ops.name_scope, Tensor>(namescope, scope => | |||
| { | |||
| return gen_math_ops.floor_mod(x, y, scope); | |||
| Tensor result = null; | |||
| var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | |||
| var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); | |||
| switch (name) | |||
| { | |||
| case "add": | |||
| result = gen_math_ops.add(x1, y1, name: scope); | |||
| break; | |||
| case "truediv": | |||
| result = gen_math_ops.real_div(x1, y1, name: scope); | |||
| break; | |||
| case "mul": | |||
| result = gen_math_ops.mul(x1, y1, name: scope); | |||
| break; | |||
| case "sub": | |||
| result = gen_math_ops.sub(x1, y1, name: scope); | |||
| break; | |||
| case "mod": | |||
| result = gen_math_ops.floor_mod(x1, y1, name: scope); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty)}"); | |||
| } | |||
| return result; | |||
| }); | |||
| } | |||
| public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y); | |||
| public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y); | |||
| public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y); | |||
| public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y); | |||
| } | |||
| } | |||
| @@ -55,7 +55,7 @@ namespace Tensorflow | |||
| /// <param name="deallocator_arg"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, ref bool deallocator_arg); | |||
| public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref bool deallocator_arg); | |||
| /// <summary> | |||
| /// Return the number of dimensions that the tensor has. | |||
| @@ -96,19 +96,19 @@ namespace Tensorflow | |||
| if (values.GetType().IsArray) | |||
| nparray = np.array((int[])values, np_dt); | |||
| else | |||
| nparray = (int)values; | |||
| nparray = Convert.ToInt32(values); | |||
| break; | |||
| case "Single": | |||
| if (values.GetType().IsArray) | |||
| nparray = np.array((float[])values, np_dt); | |||
| else | |||
| nparray = (float)values; | |||
| nparray = Convert.ToSingle(values); | |||
| break; | |||
| case "Double": | |||
| nparray = (double)values; | |||
| nparray = Convert.ToDouble(values); | |||
| break; | |||
| case "String": | |||
| nparray = values.ToString(); | |||
| nparray = Convert.ToString(values); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("make_tensor_proto Not Implemented"); | |||
| @@ -10,41 +10,47 @@ namespace TensorFlowNET.Examples | |||
| /// A linear regression learning algorithm example using TensorFlow library. | |||
| /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/linear_regression.py | |||
| /// </summary> | |||
| public class LinearRegression : IExample | |||
| public class LinearRegression : Python, IExample | |||
| { | |||
| private NumPyRandom rng = np.random; | |||
| public void Run() | |||
| { | |||
| var graph = tf.Graph().as_default(); | |||
| // Parameters | |||
| double learning_rate = 0.01; | |||
| float learning_rate = 0.01f; | |||
| int training_epochs = 1000; | |||
| int display_step = 50; | |||
| int display_step = 1; | |||
| // Training Data | |||
| var train_X = np.array(3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167, | |||
| 7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1); | |||
| var train_Y = np.array(1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221, | |||
| 2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3); | |||
| var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, | |||
| 7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f); | |||
| var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, | |||
| 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); | |||
| var n_samples = train_X.shape[0]; | |||
| // tf Graph Input | |||
| var X = tf.placeholder(tf.float64); | |||
| var Y = tf.placeholder(tf.float64); | |||
| var X = tf.placeholder(tf.float32); | |||
| var Y = tf.placeholder(tf.float32); | |||
| // Set model weights | |||
| var W = tf.Variable(rng.randn<double>(), name: "weight"); | |||
| var b = tf.Variable(rng.randn<double>(), name: "bias"); | |||
| //var rnd1 = rng.randn<float>(); | |||
| //var rnd2 = rng.randn<float>(); | |||
| var W = tf.Variable(-0.06f, name: "weight"); | |||
| var b = tf.Variable(-0.73f, name: "bias"); | |||
| var mul = tf.multiply(X, W); | |||
| var pred = tf.add(mul, b); | |||
| // Mean squared error | |||
| var sub = pred - Y; | |||
| var pow = tf.pow(sub, 2); | |||
| var pow = tf.pow(sub, 2.0f); | |||
| var reduce = tf.reduce_sum(pow); | |||
| var cost = reduce / (2d * n_samples); | |||
| var cost = reduce / (2.0f * n_samples); | |||
| // import graph | |||
| // radient descent | |||
| // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default | |||
| @@ -55,7 +61,7 @@ namespace TensorFlowNET.Examples | |||
| var init = tf.global_variables_initializer(); | |||
| // Start training | |||
| Python.with<Session>(tf.Session(), sess => | |||
| Python.with<Session>(tf.Session(graph), sess => | |||
| { | |||
| // Run the initializer | |||
| sess.run(init); | |||
| @@ -63,11 +69,12 @@ namespace TensorFlowNET.Examples | |||
| // Fit all training data | |||
| for (int epoch = 0; epoch < training_epochs; epoch++) | |||
| { | |||
| foreach (var (x, y) in Python.zip<double>(train_X, train_Y)) | |||
| foreach (var (x, y) in zip<float>(train_X, train_Y)) | |||
| { | |||
| sess.run(optimizer, | |||
| new FeedItem(X, x), | |||
| new FeedItem(Y, y)); | |||
| var w = sess.run(W); | |||
| } | |||
| // Display logs per epoch step | |||
| @@ -6,7 +6,7 @@ | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="NumSharp" Version="0.7.2" /> | |||
| <PackageReference Include="NumSharp" Version="0.7.3" /> | |||
| <PackageReference Include="TensorFlow.NET" Version="0.3.0" /> | |||
| </ItemGroup> | |||
| @@ -12,6 +12,7 @@ namespace TensorFlowNET.UnitTest | |||
| [TestMethod] | |||
| public void Gradients() | |||
| { | |||
| var graph = tf.Graph().as_default(); | |||
| var a = tf.constant(0.0); | |||
| var b = 2.0 * a; | |||
| Assert.AreEqual(b.name, "mul:0"); | |||
| @@ -19,7 +19,7 @@ | |||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | |||
| <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | |||
| <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | |||
| <PackageReference Include="NumSharp" Version="0.7.2" /> | |||
| <PackageReference Include="NumSharp" Version="0.7.3" /> | |||
| <PackageReference Include="TensorFlow.NET" Version="0.3.0" /> | |||
| </ItemGroup> | |||
| @@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest | |||
| EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); | |||
| EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); | |||
| EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 4, 2, 5, 3, 6 })); | |||
| } | |||
| /// <summary> | |||
| @@ -10,7 +10,6 @@ namespace TensorFlowNET.UnitTest | |||
| [TestClass] | |||
| public class TrainSaverTest : Python | |||
| { | |||
| [TestMethod] | |||
| public void ExportGraph() | |||
| { | |||
| var v = tf.Variable(0, name: "my_variable"); | |||
| @@ -18,7 +17,6 @@ namespace TensorFlowNET.UnitTest | |||
| tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt"); | |||
| } | |||
| [TestMethod] | |||
| public void ImportGraph() | |||
| { | |||
| with<Session>(tf.Session(), sess => | |||
| @@ -27,7 +25,6 @@ namespace TensorFlowNET.UnitTest | |||
| }); | |||
| } | |||
| [TestMethod] | |||
| public void ImportSavedModel() | |||
| { | |||
| with<Session>(Session.LoadFromSavedModel("mobilenet"), sess => | |||
| @@ -36,14 +33,12 @@ namespace TensorFlowNET.UnitTest | |||
| }); | |||
| } | |||
| [TestMethod] | |||
| public void ImportGraphDefFromPbFile() | |||
| { | |||
| var g = new Graph(); | |||
| var status = g.Import("mobilenet/saved_model.pb"); | |||
| } | |||
| [TestMethod] | |||
| public void Save1() | |||
| { | |||
| var w1 = tf.Variable(0, name: "save1"); | |||
| @@ -63,7 +58,6 @@ namespace TensorFlowNET.UnitTest | |||
| }); | |||
| } | |||
| [TestMethod] | |||
| public void Save2() | |||
| { | |||
| var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||