| @@ -322,6 +322,8 @@ namespace Tensorflow | |||
| attr_value.Shape = val1.as_proto(); | |||
| else if(value is long[] val2) | |||
| attr_value.Shape = tensor_util.as_shape(val2); | |||
| else if (value is int[] val3) | |||
| attr_value.Shape = tensor_util.as_shape(val3); | |||
| break; | |||
| default: | |||
| @@ -34,10 +34,9 @@ namespace Tensorflow | |||
| private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name) | |||
| { | |||
| Tensor tShape = null; | |||
| var nd = np.zeros<T>(shape); | |||
| if (shape.Size < 1000) | |||
| { | |||
| return constant_op.constant(nd, name: name); | |||
| return constant_op.constant(value, shape: shape, dtype: dtype, name: name); | |||
| } | |||
| else | |||
| { | |||
| @@ -43,7 +43,7 @@ TensorFlow 1.13 RC.</PackageReleaseNotes> | |||
| <ItemGroup> | |||
| <PackageReference Include="Google.Protobuf" Version="3.6.1" /> | |||
| <PackageReference Include="NumSharp" Version="0.7.0" /> | |||
| <PackageReference Include="NumSharp" Version="0.7.1" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -56,9 +56,9 @@ namespace Tensorflow | |||
| switch (values) | |||
| { | |||
| /*case bool boolVal: | |||
| case bool boolVal: | |||
| nparray = boolVal; | |||
| break;*/ | |||
| break; | |||
| case int intVal: | |||
| nparray = intVal; | |||
| break; | |||
| @@ -74,6 +74,9 @@ namespace Tensorflow | |||
| case string strVal: | |||
| nparray = strVal; | |||
| break; | |||
| case string[] strVals: | |||
| nparray = strVals; | |||
| break; | |||
| default: | |||
| throw new Exception("make_tensor_proto Not Implemented"); | |||
| } | |||
| @@ -100,7 +103,8 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("make_tensor_proto shape not implemented"); | |||
| shape_size = new TensorShape(shape).Size; | |||
| is_same_size = shape_size == nparray.size; | |||
| } | |||
| var tensor_proto = new tensor_pb2.TensorProto | |||
| @@ -111,41 +115,17 @@ namespace Tensorflow | |||
| if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) | |||
| { | |||
| var bytes = new List<byte>(); | |||
| var nd2 = nparray.ravel(); | |||
| switch (nparray.dtype.Name) | |||
| { | |||
| case "Int32": | |||
| nd2.Data<int>().Select(x => | |||
| { | |||
| bytes.AddRange(BitConverter.GetBytes(x)); | |||
| return x; | |||
| }).ToArray(); | |||
| break; | |||
| case "Single": | |||
| nd2.Data<float>().Select(x => | |||
| { | |||
| bytes.AddRange(BitConverter.GetBytes(x)); | |||
| return x; | |||
| }).ToArray(); | |||
| break; | |||
| case "Double": | |||
| nd2.Data<double>().Select(x => | |||
| { | |||
| bytes.AddRange(BitConverter.GetBytes(x)); | |||
| return x; | |||
| }).ToArray(); | |||
| break; | |||
| default: | |||
| throw new Exception("make_tensor_proto Not Implemented"); | |||
| } | |||
| byte[] bytes = nparray.ToByteArray(); | |||
| tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); | |||
| return tensor_proto; | |||
| } | |||
| if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray) && values is string str) | |||
| if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray)) | |||
| { | |||
| tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); | |||
| if (values is string str) | |||
| tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); | |||
| else if (values is string[] str_values) | |||
| tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); | |||
| return tensor_proto; | |||
| } | |||
| @@ -123,8 +123,6 @@ namespace Tensorflow | |||
| Version = _write_version | |||
| }; | |||
| }); | |||
| } | |||
| public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) | |||
| @@ -6,7 +6,7 @@ | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="NumSharp" Version="0.7.0" /> | |||
| <PackageReference Include="NumSharp" Version="0.7.1" /> | |||
| <PackageReference Include="TensorFlow.NET" Version="0.1.0" /> | |||
| </ItemGroup> | |||
| @@ -58,6 +58,7 @@ namespace TensorFlowNET.UnitTest | |||
| var data = result.Data<int>(); | |||
| Assert.AreEqual(0, data[0]); | |||
| Assert.AreEqual(0, data[500]); | |||
| Assert.AreEqual(0, data[result.size - 1]); | |||
| }); | |||
| } | |||
| @@ -109,5 +110,15 @@ namespace TensorFlowNET.UnitTest | |||
| //c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); | |||
| } | |||
| /// <summary> | |||
| /// tensorflow\c\c_api_test.cc | |||
| /// TestEncodeDecode | |||
| /// </summary> | |||
| [TestMethod] | |||
| public void EncodeDecode() | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -19,7 +19,7 @@ | |||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | |||
| <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | |||
| <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | |||
| <PackageReference Include="NumSharp" Version="0.7.0" /> | |||
| <PackageReference Include="NumSharp" Version="0.7.1" /> | |||
| <PackageReference Include="TensorFlow.NET" Version="0.1.0" /> | |||
| </ItemGroup> | |||
| @@ -10,11 +10,11 @@ namespace TensorFlowNET.UnitTest | |||
| public class TrainSaverTest : Python | |||
| { | |||
| [TestMethod] | |||
| public void WriteGraph() | |||
| public void ExportGraph() | |||
| { | |||
| var v = tf.Variable(0, name: "my_variable"); | |||
| var sess = tf.Session(); | |||
| tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||
| tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt"); | |||
| } | |||
| [TestMethod] | |||
| @@ -22,14 +22,13 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| var v = tf.Variable(0, name: "my_variable"); | |||
| var sess = tf.Session(); | |||
| tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||
| tf.train.write_graph(sess.graph, "/tmp/my-model", "train2.pbtxt"); | |||
| } | |||
| [TestMethod] | |||
| public void SaveSimple() | |||
| public void Save1() | |||
| { | |||
| var w1 = tf.Variable(tf.random_normal(new int[] { 2 }), name: "w1"); | |||
| var w2 = tf.Variable(tf.random_normal(new int[] { 5 }), name: "w2"); | |||
| var w1 = tf.Variable(0, name: "save1"); | |||
| var init_op = tf.global_variables_initializer(); | |||
| @@ -41,13 +40,13 @@ namespace TensorFlowNET.UnitTest | |||
| sess.run(init_op); | |||
| // Save the variables to disk. | |||
| var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||
| var save_path = saver.save(sess, "/tmp/model1.ckpt"); | |||
| Console.WriteLine($"Model saved in path: {save_path}"); | |||
| }); | |||
| } | |||
| [TestMethod] | |||
| public void Save() | |||
| public void Save2() | |||
| { | |||
| var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||
| var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); | |||
| @@ -69,7 +68,7 @@ namespace TensorFlowNET.UnitTest | |||
| dec_v2.op.run(); | |||
| // Save the variables to disk. | |||
| var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||
| var save_path = saver.save(sess, "/tmp/model2.ckpt"); | |||
| Console.WriteLine($"Model saved in path: {save_path}"); | |||
| }); | |||
| } | |||