diff --git a/src/TensorFlowNET.Core/APIs/tf.io.cs b/src/TensorFlowNET.Core/APIs/tf.io.cs index 40da04b1..25d9cfe8 100644 --- a/src/TensorFlowNET.Core/APIs/tf.io.cs +++ b/src/TensorFlowNET.Core/APIs/tf.io.cs @@ -21,9 +21,32 @@ namespace Tensorflow { public partial class tensorflow { + public IoApi io { get; } = new IoApi(); + + public class IoApi + { + io_ops ops; + public IoApi() + { + ops = new io_ops(); + } + + public Tensor read_file(string filename, string name = null) + => ops.read_file(filename, name); + + public Tensor read_file(Tensor filename, string name = null) + => ops.read_file(filename, name); + + public Operation save_v2(Tensor prefix, string[] tensor_names, + string[] shape_and_slices, Tensor[] tensors, string name = null) + => ops.save_v2(prefix, tensor_names, shape_and_slices, tensors, name: name); + + public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, + string[] shape_and_slices, TF_DataType[] dtypes, string name = null) + => ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name); + } + public GFile gfile = new GFile(); - public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); - public Tensor read_file(Tensor filename, string name = null) => gen_io_ops.read_file(filename, name); public ITensorOrOperation[] import_graph_def(GraphDef graph_def, Dictionary input_map = null, diff --git a/src/TensorFlowNET.Core/APIs/tf.strings.cs b/src/TensorFlowNET.Core/APIs/tf.strings.cs index 38d92803..e19136a9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.strings.cs +++ b/src/TensorFlowNET.Core/APIs/tf.strings.cs @@ -21,12 +21,28 @@ namespace Tensorflow { public partial class tensorflow { - public strings_internal strings = new strings_internal(); - public class strings_internal + public StringsApi strings { get; } = new StringsApi(); + + public class StringsApi { + string_ops ops = new string_ops(); + + /// + /// Return substrings from `Tensor` of strings. + /// + /// + /// + /// + /// + /// + /// public Tensor substr(Tensor input, int pos, int len, string name = null, string @uint = "BYTE") - => string_ops.substr(input, pos, len, name: name, @uint: @uint); + => ops.substr(input, pos, len, @uint: @uint, name: name); + + public Tensor substr(string input, int pos, int len, + string name = null, string @uint = "BYTE") + => ops.substr(input, pos, len, @uint: @uint, name: name); } } } diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs index 1c5c344f..78a63b77 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs @@ -47,7 +47,7 @@ namespace Tensorflow.Eager status.Check(true); } } - if (status.ok()) + if (status.ok() && attrs != null) SetOpAttrs(op, attrs); var outputs = new IntPtr[num_outputs]; diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index b9ca57cf..0385e588 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -204,9 +204,6 @@ namespace Tensorflow.Eager input_handle = input.EagerTensorHandle; flattened_inputs.Add(input); break; - case EagerTensor[] input_list: - input_handle = input_list[0].EagerTensorHandle; - break; default: var tensor = tf.convert_to_tensor(inputs); input_handle = (tensor as EagerTensor).EagerTensorHandle; diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 2852c05c..5df45e61 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -376,6 +376,16 @@ namespace Tensorflow { return tf_with(ops.name_scope(name, "cond", new { pred }), delegate { + if (tf.context.executing_eagerly()) + { + if (pred.ToArray()[0]) + return true_fn() as Tensor; + else + return false_fn() as Tensor; + + return null; + } + // Add the Switch to the graph. var switch_result= @switch(pred, pred); var (p_2, p_1 )= (switch_result[0], switch_result[1]); @@ -450,6 +460,16 @@ namespace Tensorflow { return tf_with(ops.name_scope(name, "cond", new { pred }), delegate { + if (tf.context.executing_eagerly()) + { + if (pred.ToArray()[0]) + return true_fn() as Tensor[]; + else + return false_fn() as Tensor[]; + + return null; + } + // Add the Switch to the graph. var switch_result = @switch(pred, pred); var p_2 = switch_result[0]; diff --git a/src/TensorFlowNET.Core/Operations/gen_string_ops.cs b/src/TensorFlowNET.Core/Operations/gen_string_ops.cs deleted file mode 100644 index bb407e77..00000000 --- a/src/TensorFlowNET.Core/Operations/gen_string_ops.cs +++ /dev/null @@ -1,40 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System; -using System.Collections.Generic; -using System.Text; -using static Tensorflow.Binding; - -namespace Tensorflow -{ - public class gen_string_ops - { - public static Tensor substr(Tensor input, int pos, int len, - string name = null, string @uint = "BYTE") - { - var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new - { - input, - pos, - len, - unit = @uint - }); - - return _op.output; - } - } -} diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index f534df6a..2cbd3c39 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -16,6 +16,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow.Operations; using static Tensorflow.Binding; @@ -63,7 +64,7 @@ namespace Tensorflow Func _bmp = () => { int bmp_channels = channels; - var signature = string_ops.substr(contents, 0, 2); + var signature = tf.strings.substr(contents, 0, 2); var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); @@ -98,7 +99,7 @@ namespace Tensorflow return tf_with(ops.name_scope(name, "decode_image"), scope => { - substr = string_ops.substr(contents, 0, 3); + substr = tf.strings.substr(contents, 0, 3); return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); }); } @@ -128,8 +129,11 @@ namespace Tensorflow { return tf_with(ops.name_scope(name, "is_jpeg"), scope => { - var substr = string_ops.substr(contents, 0, 3); - return math_ops.equal(substr, "\xff\xd8\xff", name: name); + var substr = tf.strings.substr(contents, 0, 3); + var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff }); + var jpg_tensor = tf.constant(jpg); + var result = math_ops.equal(substr, jpg_tensor, name: name); + return result; }); } @@ -137,7 +141,7 @@ namespace Tensorflow { return tf_with(ops.name_scope(name, "is_png"), scope => { - var substr = string_ops.substr(contents, 0, 3); + var substr = tf.strings.substr(contents, 0, 3); return math_ops.equal(substr, @"\211PN", name: name); }); } diff --git a/src/TensorFlowNET.Core/Operations/gen_io_ops.cs b/src/TensorFlowNET.Core/Operations/io_ops.cs similarity index 60% rename from src/TensorFlowNET.Core/Operations/gen_io_ops.cs rename to src/TensorFlowNET.Core/Operations/io_ops.cs index d7462116..9b8a9889 100644 --- a/src/TensorFlowNET.Core/Operations/gen_io_ops.cs +++ b/src/TensorFlowNET.Core/Operations/io_ops.cs @@ -14,31 +14,45 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow { - public class gen_io_ops + public class io_ops { - public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) + public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) { var _op = tf._op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); return _op; } - public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) + public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) { var _op = tf._op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); return _op.outputs; } - public static Tensor read_file(T filename, string name = null) + public Tensor read_file(T filename, string name = null) { + if (tf.context.executing_eagerly()) + { + return read_file_eager_fallback(filename, name: name, tf.context); + } + var _op = tf._op_def_lib._apply_op_helper("ReadFile", name: name, args: new { filename }); return _op.outputs[0]; } + + private Tensor read_file_eager_fallback(T filename, string name = null, Context ctx = null) + { + var filename_tensor = ops.convert_to_tensor(filename, TF_DataType.TF_STRING); + var _inputs_flat = new[] { filename_tensor }; + + return tf._execute.execute(ctx, "ReadFile", 1, _inputs_flat, null, name: name)[0]; + } } } diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs index ee46cf78..a0b46c48 100644 --- a/src/TensorFlowNET.Core/Operations/string_ops.cs +++ b/src/TensorFlowNET.Core/Operations/string_ops.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; using System.Text; +using static Tensorflow.Binding; namespace Tensorflow { @@ -31,8 +32,30 @@ namespace Tensorflow /// /// /// - public static Tensor substr(Tensor input, int pos, int len, - string name = null, string @uint = "BYTE") - => gen_string_ops.substr(input, pos, len, name: name, @uint: @uint); + public Tensor substr(T input, int pos, int len, + string @uint = "BYTE", string name = null) + { + if (tf.context.executing_eagerly()) + { + var input_tensor = tf.constant(input); + var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Substr", name, + null, + input, pos, len, + "unit", @uint); + + return results[0]; + } + + var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new + { + input, + pos, + len, + unit = @uint + }); + + return _op.output; + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs index aa9e7d90..1845e9fd 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs @@ -68,9 +68,9 @@ namespace Tensorflow throw new ArgumentException($"{nameof(Tensor)} can only be scalar."); IntPtr stringStartAddress = IntPtr.Zero; - UIntPtr dstLen = UIntPtr.Zero; + ulong dstLen = 0; - c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, tf.status.Handle); + c_api.TF_StringDecode((byte*) this.buffer + 8, this.bytesize, (byte**) &stringStartAddress, ref dstLen, tf.status.Handle); tf.status.Check(true); var dstLenInt = checked((int) dstLen); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 3d7e4cbc..d1a75338 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -453,7 +453,7 @@ namespace Tensorflow { var buffer = Encoding.UTF8.GetBytes(str); var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + sizeof(ulong))); AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index 04b22a68..579ff566 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -235,13 +235,12 @@ namespace Tensorflow var buffer = new byte[size][]; var src = c_api.TF_TensorData(_handle); - var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); src += (int)(size * 8); for (int i = 0; i < buffer.Length; i++) { IntPtr dst = IntPtr.Zero; - UIntPtr dstLen = UIntPtr.Zero; - var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, tf.status.Handle); + ulong dstLen = 0; + var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle); tf.status.Check(true); buffer[i] = new byte[(int)dstLen]; Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); @@ -254,5 +253,35 @@ namespace Tensorflow return _str; } + + public unsafe byte[][] StringBytes() + { + if (dtype != TF_DataType.TF_STRING) + throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); + + // + // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. + // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] + // + long size = 1; + foreach (var s in TensorShape.dims) + size *= s; + + var buffer = new byte[size][]; + var src = c_api.TF_TensorData(_handle); + src += (int)(size * 8); + for (int i = 0; i < buffer.Length; i++) + { + IntPtr dst = IntPtr.Zero; + ulong dstLen = 0; + var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle); + tf.status.Check(true); + buffer[i] = new byte[(int)dstLen]; + Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); + src += (int)read; + } + + return buffer; + } } } diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index ebc2b192..c9dd5e13 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -207,7 +207,7 @@ namespace Tensorflow public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, SafeStatusHandle status); + public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status); public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator; diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index a7821c93..d3c28938 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -132,10 +132,22 @@ namespace Tensorflow switch (value) { + case EagerTensor val: + return val; case NDArray val: return new EagerTensor(val, ctx.device_name); case string val: return new EagerTensor(val, ctx.device_name); + case bool val: + return new EagerTensor(val, ctx.device_name); + case byte val: + return new EagerTensor(val, ctx.device_name); + case byte[] val: + return new EagerTensor(val, ctx.device_name); + case byte[,] val: + return new EagerTensor(val, ctx.device_name); + case byte[,,] val: + return new EagerTensor(val, ctx.device_name); case int val: return new EagerTensor(val, ctx.device_name); case int[] val: diff --git a/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs index 1aae389b..7ebf94d6 100644 --- a/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs @@ -55,7 +55,7 @@ namespace Tensorflow if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) { - return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); + return tf.io.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); } else { @@ -76,7 +76,7 @@ namespace Tensorflow dtypes.Add(spec.dtype); } - return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); + return tf.io.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); } public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables, diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs index 344e4374..cb3ea87a 100644 --- a/test/TensorFlowNET.UnitTest/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -160,7 +160,6 @@ namespace TensorFlowNET.UnitTest.Basics Assert.AreEqual(6.0, (double)c); } - [Ignore] [TestMethod] public void StringEncode() { @@ -175,7 +174,7 @@ namespace TensorFlowNET.UnitTest.Basics string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); Assert.AreEqual(encoded_str, str); Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); - //c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); + // c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status.Handle); } [TestMethod] diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs index d94101cc..02ae5e43 100644 --- a/test/TensorFlowNET.UnitTest/ImageTest.cs +++ b/test/TensorFlowNET.UnitTest/ImageTest.cs @@ -2,8 +2,10 @@ using System; using System.Collections.Generic; using System.IO; +using System.Reflection; using System.Text; using Tensorflow; +using Tensorflow.UnitTest; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.Basics @@ -20,11 +22,10 @@ namespace TensorFlowNET.UnitTest.Basics [TestInitialize] public void Initialize() { - imgPath = Path.GetFullPath(imgPath); - contents = tf.read_file(imgPath); + imgPath = TestHelper.GetFullPathFromDataDir(imgPath); + contents = tf.io.read_file(imgPath); } - [Ignore("")] [TestMethod] public void decode_image() { diff --git a/test/TensorFlowNET.UnitTest/nn_test/ActivationFunctionTest.cs b/test/TensorFlowNET.UnitTest/TF_API/ActivationFunctionTest.cs similarity index 100% rename from test/TensorFlowNET.UnitTest/nn_test/ActivationFunctionTest.cs rename to test/TensorFlowNET.UnitTest/TF_API/ActivationFunctionTest.cs diff --git a/test/TensorFlowNET.UnitTest/math_test/MathOperationTest.cs b/test/TensorFlowNET.UnitTest/TF_API/MathApiTest.cs similarity index 91% rename from test/TensorFlowNET.UnitTest/math_test/MathOperationTest.cs rename to test/TensorFlowNET.UnitTest/TF_API/MathApiTest.cs index ccc9c2d9..12023bd4 100644 --- a/test/TensorFlowNET.UnitTest/math_test/MathOperationTest.cs +++ b/test/TensorFlowNET.UnitTest/TF_API/MathApiTest.cs @@ -6,10 +6,10 @@ using System.Text; using Tensorflow; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.math_test +namespace TensorFlowNET.UnitTest.TF_API { [TestClass] - public class MathOperationTest : TFNetApiTest + public class MathApiTest : TFNetApiTest { // A constant vector of size 6 Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); diff --git a/test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs b/test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs new file mode 100644 index 00000000..3049505b --- /dev/null +++ b/test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs @@ -0,0 +1,43 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.UnitTest.TF_API +{ + [TestClass] + public class StringsApiTest + { + [TestMethod] + public void StringEqual() + { + var str1 = tf.constant("Hello1"); + var str2 = tf.constant("Hello2"); + var result = tf.equal(str1, str2); + Assert.IsFalse(result.ToScalar()); + + var str3 = tf.constant("Hello1"); + result = tf.equal(str1, str3); + Assert.IsTrue(result.ToScalar()); + + var str4 = tf.strings.substr(str1, 0, 5); + var str5 = tf.strings.substr(str2, 0, 5); + result = tf.equal(str4, str5); + Assert.IsTrue(result.ToScalar()); + } + + [TestMethod] + public void ImageType() + { + var imgPath = TestHelper.GetFullPathFromDataDir("shasta-daisy.jpg"); + var contents = tf.io.read_file(imgPath); + + var substr = tf.strings.substr(contents, 0, 3); + var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff }); + var jpg_tensor = tf.constant(jpg); + + var result = math_ops.equal(substr, jpg_tensor); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/TFNetApiTest.cs b/test/TensorFlowNET.UnitTest/TF_API/TFNetApiTest.cs similarity index 100% rename from test/TensorFlowNET.UnitTest/TFNetApiTest.cs rename to test/TensorFlowNET.UnitTest/TF_API/TFNetApiTest.cs diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/TF_API/ZeroFractionTest.cs similarity index 100% rename from test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs rename to test/TensorFlowNET.UnitTest/TF_API/ZeroFractionTest.cs diff --git a/test/TensorFlowNET.UnitTest/nn_test/nn_test.py b/test/TensorFlowNET.UnitTest/TF_API/nn_test.py similarity index 100% rename from test/TensorFlowNET.UnitTest/nn_test/nn_test.py rename to test/TensorFlowNET.UnitTest/TF_API/nn_test.py diff --git a/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs b/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs new file mode 100644 index 00000000..dbc0d3a6 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; + +namespace Tensorflow.UnitTest +{ + public class TestHelper + { + public static string GetFullPathFromDataDir(string fileName) + { + var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data"); + return Path.GetFullPath(Path.Combine(dir, fileName)); + } + } +}