| @@ -18,7 +18,7 @@ namespace Tensorflow | |||||
| public static Tensor divide<T>(Tensor x, T[] y, string name = "") where T : struct | public static Tensor divide<T>(Tensor x, T[] y, string name = "") where T : struct | ||||
| => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); | => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); | ||||
| public static Tensor pow(Tensor x, double y) => gen_math_ops.pow(x, y); | |||||
| public static Tensor pow<T1, T2>(T1 x, T2 y) => gen_math_ops.pow(x, y); | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes the sum of elements across dimensions of a tensor. | /// Computes the sum of elements across dimensions of a tensor. | ||||
| @@ -357,7 +357,7 @@ namespace Tensorflow | |||||
| if (y.dtype.is_complex()) | if (y.dtype.is_complex()) | ||||
| throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); | throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); | ||||
| var shape = array_ops.shape(y); | var shape = array_ops.shape(y); | ||||
| var constant = constant_op.constant(1.0, name: $"grad_ys_{i}"); | |||||
| var constant = constant_op.constant(1.0f, name: $"grad_ys_{i}"); | |||||
| var fill = gen_array_ops.fill(shape, constant); | var fill = gen_array_ops.fill(shape, constant); | ||||
| new_grad_ys.Add(fill); | new_grad_ys.Add(fill); | ||||
| } | } | ||||
| @@ -41,6 +41,7 @@ namespace Tensorflow | |||||
| public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) | public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) | ||||
| { | { | ||||
| var ret = new Operation(c_op); | var ret = new Operation(c_op); | ||||
| _add_op(ret); | |||||
| var name_key = ret.name.ToLower(); | var name_key = ret.name.ToLower(); | ||||
| if (!_names_in_use.ContainsKey(name_key)) | if (!_names_in_use.ContainsKey(name_key)) | ||||
| @@ -212,6 +212,7 @@ namespace Tensorflow | |||||
| public void _add_op(Operation op) | public void _add_op(Operation op) | ||||
| { | { | ||||
| op._id_value = _next_id(); | |||||
| _nodes_by_id[op._id] = op; | _nodes_by_id[op._id] = op; | ||||
| _nodes_by_name[op.name] = op; | _nodes_by_name[op.name] = op; | ||||
| _version = Math.Max(_version, op._id); | _version = Math.Max(_version, op._id); | ||||
| @@ -13,7 +13,7 @@ namespace Tensorflow | |||||
| public Graph graph { get; } | public Graph graph { get; } | ||||
| public int _id => _id_value; | public int _id => _id_value; | ||||
| private int _id_value; | |||||
| public int _id_value; | |||||
| public string type => OpType; | public string type => OpType; | ||||
| public Operation op => this; | public Operation op => this; | ||||
| @@ -46,8 +46,6 @@ namespace Tensorflow | |||||
| _outputs = new Tensor[NumOutputs]; | _outputs = new Tensor[NumOutputs]; | ||||
| for (int i = 0; i < NumOutputs; i++) | for (int i = 0; i < NumOutputs; i++) | ||||
| _outputs[i] = new Tensor(this, i, OutputType(i)); | _outputs[i] = new Tensor(this, i, OutputType(i)); | ||||
| graph._add_op(this); | |||||
| } | } | ||||
| public Operation(Graph g, string opType, string oper_name) | public Operation(Graph g, string opType, string oper_name) | ||||
| @@ -100,8 +98,6 @@ namespace Tensorflow | |||||
| } | } | ||||
| // This will be set by self.inputs. | // This will be set by self.inputs. | ||||
| _id_value = graph._next_id(); | |||||
| if(op_def == null) | if(op_def == null) | ||||
| op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
| @@ -88,8 +88,8 @@ namespace Tensorflow | |||||
| public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2) | public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2) | ||||
| { | { | ||||
| int index = 0; | |||||
| yield return(t1.Data<T>(index), t2.Data<T>(index)); | |||||
| for (int i = 0; i < t1.size; i++) | |||||
| yield return (t1.Data<T>(i), t2.Data<T>(i)); | |||||
| } | } | ||||
| public static IEnumerable<(T1, T2)> zip<T1, T2>(IList<T1> t1, IList<T2> t2) | public static IEnumerable<(T1, T2)> zip<T1, T2>(IList<T1> t1, IList<T2> t2) | ||||
| @@ -59,6 +59,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | ||||
| var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); | var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); | ||||
| switch (subfeed_val) | switch (subfeed_val) | ||||
| { | { | ||||
| case IntPtr pointer: | case IntPtr pointer: | ||||
| @@ -86,6 +87,7 @@ namespace Tensorflow | |||||
| Console.WriteLine($"can't handle data type of subfeed_val"); | Console.WriteLine($"can't handle data type of subfeed_val"); | ||||
| throw new NotImplementedException("_run subfeed"); | throw new NotImplementedException("_run subfeed"); | ||||
| } | } | ||||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | ||||
| } | } | ||||
| } | } | ||||
| @@ -45,7 +45,7 @@ Upgraded to TensorFlow 1.13 RC2. | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Google.Protobuf" Version="3.6.1" /> | <PackageReference Include="Google.Protobuf" Version="3.6.1" /> | ||||
| <PackageReference Include="NumSharp" Version="0.7.2" /> | |||||
| <PackageReference Include="NumSharp" Version="0.7.3" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -52,20 +52,45 @@ namespace Tensorflow | |||||
| var dataType = ToTFDataType(nd.dtype); | var dataType = ToTFDataType(nd.dtype); | ||||
| // shape | // shape | ||||
| var dims = nd.shape.Select(x => (long)x).ToArray(); | var dims = nd.shape.Select(x => (long)x).ToArray(); | ||||
| var nd1 = nd.ravel(); | |||||
| switch (nd.dtype.Name) | switch (nd.dtype.Name) | ||||
| { | { | ||||
| case "Int16": | case "Int16": | ||||
| Marshal.Copy(nd.ravel().Data<short>(), 0, dotHandle, nd.size); | |||||
| Marshal.Copy(nd1.Data<short>(), 0, dotHandle, nd.size); | |||||
| break; | break; | ||||
| case "Int32": | case "Int32": | ||||
| Marshal.Copy(nd.ravel().Data<int>(), 0, dotHandle, nd.size); | |||||
| Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size); | |||||
| break; | break; | ||||
| case "Single": | case "Single": | ||||
| Marshal.Copy(nd.ravel().Data<float>(), 0, dotHandle, nd.size); | |||||
| Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | |||||
| /*if (nd.size > 1) | |||||
| { | |||||
| var bb = nd.Data<byte>(); | |||||
| var bytes = Marshal.AllocHGlobal(bb.Length); | |||||
| Marshal.Copy(bb, 0, bytes, bb.Length); | |||||
| ulong bytes_len = c_api.TF_StringEncodedSize((ulong)bb.Length); | |||||
| var dataTypeByte = ToTFDataType(nd.dtype); | |||||
| // shape | |||||
| var dims2 = nd.shape.Select(x => (long)x).ToArray(); | |||||
| var tfHandle2 = c_api.TF_AllocateTensor(dataTypeByte, | |||||
| dims2, | |||||
| nd.ndim, | |||||
| bytes_len + sizeof(Int64)); | |||||
| dotHandle = c_api.TF_TensorData(tfHandle2); | |||||
| Marshal.WriteInt64(dotHandle, 0); | |||||
| c_api.TF_StringEncode(bytes, (ulong)bb.Length, dotHandle + sizeof(Int64), bytes_len, status); | |||||
| return tfHandle2; | |||||
| } | |||||
| else | |||||
| { | |||||
| Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | |||||
| }*/ | |||||
| break; | break; | ||||
| case "Double": | case "Double": | ||||
| Marshal.Copy(nd.ravel().Data<double>(), 0, dotHandle, nd.size); | |||||
| Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size); | |||||
| break; | break; | ||||
| //case "Byte": | //case "Byte": | ||||
| /*var bb = nd.Data<byte>(); | /*var bb = nd.Data<byte>(); | ||||
| @@ -119,7 +144,7 @@ namespace Tensorflow | |||||
| dims, | dims, | ||||
| dims.Length, | dims.Length, | ||||
| dotHandle, | dotHandle, | ||||
| size, | |||||
| (UIntPtr)size, | |||||
| deallocator, | deallocator, | ||||
| ref deallocator_called); | ref deallocator_called); | ||||
| @@ -6,83 +6,70 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class Tensor | public partial class Tensor | ||||
| { | { | ||||
| public static Tensor operator +(Tensor x, Tensor y) | |||||
| { | |||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "add", new Tensor[] { x, y }), scope => | |||||
| { | |||||
| return gen_math_ops.add(x, y, scope); | |||||
| }); | |||||
| } | |||||
| public static Tensor operator +(Tensor x, int y) | |||||
| { | |||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "add", new object[] { x, y }), scope => | |||||
| { | |||||
| var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y"); | |||||
| return gen_math_ops.add(x, y1, scope); | |||||
| }); | |||||
| } | |||||
| public static Tensor operator +(Tensor x, Tensor y) => BinaryOpWrapper("add", x, y); | |||||
| public static Tensor operator +(Tensor x, int y) => BinaryOpWrapper("add", x, y); | |||||
| public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1); | public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1); | ||||
| public static Tensor operator -(Tensor t1, Tensor t2) => gen_math_ops.sub(t1, t2); | |||||
| public static Tensor operator -(Tensor t1, int t2) => gen_math_ops.sub(t1, t2); | |||||
| public static Tensor operator -(Tensor t1, double t2) => gen_math_ops.sub(t1, t2); | |||||
| public static Tensor operator *(double x, Tensor y) | |||||
| { | |||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new { x, y }), | |||||
| scope => | |||||
| { | |||||
| var x1 = ops.convert_to_tensor(x, y.dtype.as_base_dtype(), name: "x"); | |||||
| return gen_math_ops.mul(x1, y, name: scope); | |||||
| }); | |||||
| } | |||||
| public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y); | |||||
| public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y); | |||||
| public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y); | |||||
| public static Tensor operator *(Tensor x, Tensor y) | |||||
| { | |||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new Tensor[] { x, y }), scope => | |||||
| { | |||||
| return gen_math_ops.mul(x, y, name: scope); | |||||
| }); | |||||
| } | |||||
| public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y); | |||||
| public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y); | |||||
| public static Tensor operator *(Tensor x, Tensor y) => BinaryOpWrapper("mul", x, y); | |||||
| public static Tensor operator *(Tensor x, int y) => BinaryOpWrapper("mul", x, y); | |||||
| public static Tensor operator *(Tensor x, int y) | |||||
| { | |||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mul", new object[] { x, y }), scope => | |||||
| { | |||||
| var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y"); | |||||
| return gen_math_ops.mul(x, y1, name: scope); | |||||
| }); | |||||
| } | |||||
| public static Tensor operator /(Tensor x, Tensor y) => BinaryOpWrapper("truediv", x, y); | |||||
| public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y); | |||||
| public static Tensor operator /(Tensor x, double y) => BinaryOpWrapper("truediv", x, y); | |||||
| public static Tensor operator /(Tensor x, Tensor y) | |||||
| { | |||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("truediv/", "truediv", new Tensor[] { x, y }), scope => | |||||
| { | |||||
| return gen_math_ops.real_div(x, y, scope); | |||||
| }); | |||||
| } | |||||
| public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y); | |||||
| public static Tensor operator /(Tensor x, double y) | |||||
| { | |||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("truediv/", "truediv", new object[] { x, y }), scope => | |||||
| { | |||||
| var y1 = ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); | |||||
| return gen_math_ops.real_div(x, y1, scope); | |||||
| }); | |||||
| } | |||||
| public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y); | |||||
| public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y); | |||||
| public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y); | |||||
| public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y); | |||||
| public static Tensor operator %(Tensor x, Tensor y) | |||||
| private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | |||||
| { | { | ||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", "mod", new object[] { x, y }), scope => | |||||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||||
| if (x is Tensor tl) | |||||
| dtype = tl.dtype.as_base_dtype(); | |||||
| if( y is Tensor tr) | |||||
| dtype = tr.dtype.as_base_dtype(); | |||||
| var namescope = new ops.name_scope("", name, new { x, y }); | |||||
| return Python.with<ops.name_scope, Tensor>(namescope, scope => | |||||
| { | { | ||||
| return gen_math_ops.floor_mod(x, y, scope); | |||||
| Tensor result = null; | |||||
| var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | |||||
| var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); | |||||
| switch (name) | |||||
| { | |||||
| case "add": | |||||
| result = gen_math_ops.add(x1, y1, name: scope); | |||||
| break; | |||||
| case "truediv": | |||||
| result = gen_math_ops.real_div(x1, y1, name: scope); | |||||
| break; | |||||
| case "mul": | |||||
| result = gen_math_ops.mul(x1, y1, name: scope); | |||||
| break; | |||||
| case "sub": | |||||
| result = gen_math_ops.sub(x1, y1, name: scope); | |||||
| break; | |||||
| case "mod": | |||||
| result = gen_math_ops.floor_mod(x1, y1, name: scope); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty)}"); | |||||
| } | |||||
| return result; | |||||
| }); | }); | ||||
| } | } | ||||
| public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y); | |||||
| public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y); | |||||
| public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y); | |||||
| public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y); | |||||
| } | } | ||||
| } | } | ||||
| @@ -55,7 +55,7 @@ namespace Tensorflow | |||||
| /// <param name="deallocator_arg"></param> | /// <param name="deallocator_arg"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, ref bool deallocator_arg); | |||||
| public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref bool deallocator_arg); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return the number of dimensions that the tensor has. | /// Return the number of dimensions that the tensor has. | ||||
| @@ -96,19 +96,19 @@ namespace Tensorflow | |||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((int[])values, np_dt); | nparray = np.array((int[])values, np_dt); | ||||
| else | else | ||||
| nparray = (int)values; | |||||
| nparray = Convert.ToInt32(values); | |||||
| break; | break; | ||||
| case "Single": | case "Single": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((float[])values, np_dt); | nparray = np.array((float[])values, np_dt); | ||||
| else | else | ||||
| nparray = (float)values; | |||||
| nparray = Convert.ToSingle(values); | |||||
| break; | break; | ||||
| case "Double": | case "Double": | ||||
| nparray = (double)values; | |||||
| nparray = Convert.ToDouble(values); | |||||
| break; | break; | ||||
| case "String": | case "String": | ||||
| nparray = values.ToString(); | |||||
| nparray = Convert.ToString(values); | |||||
| break; | break; | ||||
| default: | default: | ||||
| throw new NotImplementedException("make_tensor_proto Not Implemented"); | throw new NotImplementedException("make_tensor_proto Not Implemented"); | ||||
| @@ -10,41 +10,47 @@ namespace TensorFlowNET.Examples | |||||
| /// A linear regression learning algorithm example using TensorFlow library. | /// A linear regression learning algorithm example using TensorFlow library. | ||||
| /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/linear_regression.py | /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/linear_regression.py | ||||
| /// </summary> | /// </summary> | ||||
| public class LinearRegression : IExample | |||||
| public class LinearRegression : Python, IExample | |||||
| { | { | ||||
| private NumPyRandom rng = np.random; | private NumPyRandom rng = np.random; | ||||
| public void Run() | public void Run() | ||||
| { | { | ||||
| var graph = tf.Graph().as_default(); | |||||
| // Parameters | // Parameters | ||||
| double learning_rate = 0.01; | |||||
| float learning_rate = 0.01f; | |||||
| int training_epochs = 1000; | int training_epochs = 1000; | ||||
| int display_step = 50; | |||||
| int display_step = 1; | |||||
| // Training Data | // Training Data | ||||
| var train_X = np.array(3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167, | |||||
| 7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1); | |||||
| var train_Y = np.array(1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221, | |||||
| 2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3); | |||||
| var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, | |||||
| 7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f); | |||||
| var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, | |||||
| 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); | |||||
| var n_samples = train_X.shape[0]; | var n_samples = train_X.shape[0]; | ||||
| // tf Graph Input | // tf Graph Input | ||||
| var X = tf.placeholder(tf.float64); | |||||
| var Y = tf.placeholder(tf.float64); | |||||
| var X = tf.placeholder(tf.float32); | |||||
| var Y = tf.placeholder(tf.float32); | |||||
| // Set model weights | // Set model weights | ||||
| var W = tf.Variable(rng.randn<double>(), name: "weight"); | |||||
| var b = tf.Variable(rng.randn<double>(), name: "bias"); | |||||
| //var rnd1 = rng.randn<float>(); | |||||
| //var rnd2 = rng.randn<float>(); | |||||
| var W = tf.Variable(-0.06f, name: "weight"); | |||||
| var b = tf.Variable(-0.73f, name: "bias"); | |||||
| var mul = tf.multiply(X, W); | var mul = tf.multiply(X, W); | ||||
| var pred = tf.add(mul, b); | var pred = tf.add(mul, b); | ||||
| // Mean squared error | // Mean squared error | ||||
| var sub = pred - Y; | var sub = pred - Y; | ||||
| var pow = tf.pow(sub, 2); | |||||
| var pow = tf.pow(sub, 2.0f); | |||||
| var reduce = tf.reduce_sum(pow); | var reduce = tf.reduce_sum(pow); | ||||
| var cost = reduce / (2d * n_samples); | |||||
| var cost = reduce / (2.0f * n_samples); | |||||
| // import graph | |||||
| // radient descent | // radient descent | ||||
| // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default | // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default | ||||
| @@ -55,7 +61,7 @@ namespace TensorFlowNET.Examples | |||||
| var init = tf.global_variables_initializer(); | var init = tf.global_variables_initializer(); | ||||
| // Start training | // Start training | ||||
| Python.with<Session>(tf.Session(), sess => | |||||
| Python.with<Session>(tf.Session(graph), sess => | |||||
| { | { | ||||
| // Run the initializer | // Run the initializer | ||||
| sess.run(init); | sess.run(init); | ||||
| @@ -63,11 +69,12 @@ namespace TensorFlowNET.Examples | |||||
| // Fit all training data | // Fit all training data | ||||
| for (int epoch = 0; epoch < training_epochs; epoch++) | for (int epoch = 0; epoch < training_epochs; epoch++) | ||||
| { | { | ||||
| foreach (var (x, y) in Python.zip<double>(train_X, train_Y)) | |||||
| foreach (var (x, y) in zip<float>(train_X, train_Y)) | |||||
| { | { | ||||
| sess.run(optimizer, | sess.run(optimizer, | ||||
| new FeedItem(X, x), | new FeedItem(X, x), | ||||
| new FeedItem(Y, y)); | new FeedItem(Y, y)); | ||||
| var w = sess.run(W); | |||||
| } | } | ||||
| // Display logs per epoch step | // Display logs per epoch step | ||||
| @@ -6,7 +6,7 @@ | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="NumSharp" Version="0.7.2" /> | |||||
| <PackageReference Include="NumSharp" Version="0.7.3" /> | |||||
| <PackageReference Include="TensorFlow.NET" Version="0.3.0" /> | <PackageReference Include="TensorFlow.NET" Version="0.3.0" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -12,6 +12,7 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Gradients() | public void Gradients() | ||||
| { | { | ||||
| var graph = tf.Graph().as_default(); | |||||
| var a = tf.constant(0.0); | var a = tf.constant(0.0); | ||||
| var b = 2.0 * a; | var b = 2.0 * a; | ||||
| Assert.AreEqual(b.name, "mul:0"); | Assert.AreEqual(b.name, "mul:0"); | ||||
| @@ -19,7 +19,7 @@ | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | ||||
| <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
| <PackageReference Include="NumSharp" Version="0.7.2" /> | |||||
| <PackageReference Include="NumSharp" Version="0.7.3" /> | |||||
| <PackageReference Include="TensorFlow.NET" Version="0.3.0" /> | <PackageReference Include="TensorFlow.NET" Version="0.3.0" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest | |||||
| EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); | EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); | ||||
| EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); | EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); | ||||
| EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); | EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); | ||||
| Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 4, 2, 5, 3, 6 })); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -10,7 +10,6 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestClass] | [TestClass] | ||||
| public class TrainSaverTest : Python | public class TrainSaverTest : Python | ||||
| { | { | ||||
| [TestMethod] | |||||
| public void ExportGraph() | public void ExportGraph() | ||||
| { | { | ||||
| var v = tf.Variable(0, name: "my_variable"); | var v = tf.Variable(0, name: "my_variable"); | ||||
| @@ -18,7 +17,6 @@ namespace TensorFlowNET.UnitTest | |||||
| tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt"); | tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt"); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void ImportGraph() | public void ImportGraph() | ||||
| { | { | ||||
| with<Session>(tf.Session(), sess => | with<Session>(tf.Session(), sess => | ||||
| @@ -27,7 +25,6 @@ namespace TensorFlowNET.UnitTest | |||||
| }); | }); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void ImportSavedModel() | public void ImportSavedModel() | ||||
| { | { | ||||
| with<Session>(Session.LoadFromSavedModel("mobilenet"), sess => | with<Session>(Session.LoadFromSavedModel("mobilenet"), sess => | ||||
| @@ -36,14 +33,12 @@ namespace TensorFlowNET.UnitTest | |||||
| }); | }); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void ImportGraphDefFromPbFile() | public void ImportGraphDefFromPbFile() | ||||
| { | { | ||||
| var g = new Graph(); | var g = new Graph(); | ||||
| var status = g.Import("mobilenet/saved_model.pb"); | var status = g.Import("mobilenet/saved_model.pb"); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Save1() | public void Save1() | ||||
| { | { | ||||
| var w1 = tf.Variable(0, name: "save1"); | var w1 = tf.Variable(0, name: "save1"); | ||||
| @@ -63,7 +58,6 @@ namespace TensorFlowNET.UnitTest | |||||
| }); | }); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Save2() | public void Save2() | ||||
| { | { | ||||
| var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | ||||