diff --git a/src/TensorFlowNET.Core/APIs/tf.image.cs b/src/TensorFlowNET.Core/APIs/tf.image.cs index 61c91052..e2e3206b 100644 --- a/src/TensorFlowNET.Core/APIs/tf.image.cs +++ b/src/TensorFlowNET.Core/APIs/tf.image.cs @@ -42,6 +42,20 @@ namespace Tensorflow public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null) => gen_image_ops.convert_image_dtype(image, dtype, saturate: saturate, name: name); + + public Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, + string name = null, bool expand_animations = true) + => image_ops_impl.decode_image(contents, channels: channels, dtype: dtype, + name: name, expand_animations: expand_animations); + + /// + /// Convenience function to check if the 'contents' encodes a JPEG image. + /// + /// + /// + /// + public static Tensor is_jpeg(Tensor contents, string name = null) + => image_ops_impl.is_jpeg(contents, name: name); } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.strings.cs b/src/TensorFlowNET.Core/APIs/tf.strings.cs new file mode 100644 index 00000000..38d92803 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.strings.cs @@ -0,0 +1,32 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.IO; + +namespace Tensorflow +{ + public partial class tensorflow + { + public strings_internal strings = new strings_internal(); + public class strings_internal + { + public Tensor substr(Tensor input, int pos, int len, + string name = null, string @uint = "BYTE") + => string_ops.substr(input, pos, len, name: name, @uint: @uint); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs index 3cd1a83e..90893815 100644 --- a/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs @@ -88,6 +88,69 @@ namespace Tensorflow } } + public static Tensor decode_gif(Tensor contents, + string name = null) + { + // Add nodes to the TensorFlow graph. + if (tf.context.executing_eagerly()) + { + throw new NotImplementedException("decode_gif"); + } + else + { + var _op = _op_def_lib._apply_op_helper("DecodeGif", name: name, args: new + { + contents + }); + + return _op.output; + } + } + + public static Tensor decode_png(Tensor contents, + int channels = 0, + TF_DataType dtype = TF_DataType.TF_UINT8, + string name = null) + { + // Add nodes to the TensorFlow graph. + if (tf.context.executing_eagerly()) + { + throw new NotImplementedException("decode_png"); + } + else + { + var _op = _op_def_lib._apply_op_helper("DecodePng", name: name, args: new + { + contents, + channels, + dtype + }); + + return _op.output; + } + } + + public static Tensor decode_bmp(Tensor contents, + int channels = 0, + string name = null) + { + // Add nodes to the TensorFlow graph. + if (tf.context.executing_eagerly()) + { + throw new NotImplementedException("decode_bmp"); + } + else + { + var _op = _op_def_lib._apply_op_helper("DecodeBmp", name: name, args: new + { + contents, + channels + }); + + return _op.output; + } + } + public static Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) { if (tf.context.executing_eagerly()) diff --git a/src/TensorFlowNET.Core/Operations/gen_string_ops.cs b/src/TensorFlowNET.Core/Operations/gen_string_ops.cs new file mode 100644 index 00000000..87ac589e --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_string_ops.cs @@ -0,0 +1,42 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class gen_string_ops + { + static readonly OpDefLibrary _op_def_lib; + static gen_string_ops() { _op_def_lib = new OpDefLibrary(); } + + public static Tensor substr(Tensor input, int pos, int len, + string name = null, string @uint = "BYTE") + { + var _op = _op_def_lib._apply_op_helper("Substr", name: name, args: new + { + input, + pos, + len, + unit = @uint + }); + + return _op.output; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index d21f6e8d..65ed8eb1 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -17,11 +17,116 @@ using System; using System.Collections.Generic; using System.Text; +using static Tensorflow.Binding; namespace Tensorflow { public class image_ops_impl { + public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, + string name = null, bool expand_animations = true) + { + Tensor substr = null; + Func _jpeg = () => + { + int jpeg_channels = channels; + var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels"); + string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'"; + var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); + return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate + { + return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype); + }); + }; + + Func _gif = () => + { + int gif_channels = channels; + var good_channels = math_ops.logical_and( + math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"), + math_ops.not_equal(gif_channels, 4, name: "check_gif_channels")); + + string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images"; + var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); + return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate + { + var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype); + if (!expand_animations) + // result = array_ops.gather(result, 0); + throw new NotImplementedException(""); + return result; + }); + }; + + Func _bmp = () => + { + int bmp_channels = channels; + var signature = string_ops.substr(contents, 0, 2); + var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); + string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; + var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); + var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels"); + string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images"; + var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); + return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate + { + return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype); + }); + }; + + Func _png = () => + { + return convert_image_dtype(gen_image_ops.decode_png( + contents, + channels, + dtype: dtype), + dtype); + }; + + Func check_gif = () => + { + var is_gif = math_ops.equal(substr, "\x47\x49\x46", name: "is_gif"); + return control_flow_ops.cond(is_gif, _gif, _bmp, name: "cond_gif"); + }; + + Func check_png = () => + { + return control_flow_ops.cond(_is_png(contents), _png, check_gif, name: "cond_png"); + }; + + return tf_with(ops.name_scope(name, "decode_image"), scope => + { + substr = string_ops.substr(contents, 0, 3); + return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); + }); + } + + public static Tensor is_jpeg(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "is_jpeg"), scope => + { + var substr = string_ops.substr(contents, 0, 3); + return math_ops.equal(substr, "\xff\xd8\xff", name: name); + }); + } + + public static Tensor _is_png(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "is_png"), scope => + { + var substr = string_ops.substr(contents, 0, 3); + return math_ops.equal(substr, @"\211PN", name: name); + }); + } + + public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, + string name = null) + { + if (dtype == image.dtype) + return array_ops.identity(image, name: name); + + throw new NotImplementedException(""); + } } } diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index fa1fda12..f5cfdb37 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -168,6 +168,9 @@ namespace Tensorflow public static Tensor multiply(Tx x, Ty y, string name = null) => gen_math_ops.mul(x, y, name: name); + public static Tensor not_equal(Tx x, Ty y, string name = null) + => gen_math_ops.not_equal(x, y, name: name); + public static Tensor mul_no_nan(Tx x, Ty y, string name = null) => gen_math_ops.mul_no_nan(x, y, name: name); @@ -264,6 +267,9 @@ namespace Tensorflow return gen_math_ops.log(x, name); } + public static Tensor logical_and(Tensor x, Tensor y, string name = null) + => gen_math_ops.logical_and(x, y, name: name); + public static Tensor lgamma(Tensor x, string name = null) => gen_math_ops.lgamma(x, name: name); diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs new file mode 100644 index 00000000..ee46cf78 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/string_ops.cs @@ -0,0 +1,38 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class string_ops + { + /// + /// Return substrings from `Tensor` of strings. + /// + /// + /// + /// + /// + /// + /// + public static Tensor substr(Tensor input, int pos, int len, + string name = null, string @uint = "BYTE") + => gen_string_ops.substr(input, pos, len, name: name, @uint: @uint); + } +} diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index f5431e01..c088d7e6 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -416,12 +416,13 @@ namespace TensorFlowNET.UnitTest } + [TestMethod] public void ImportGraphMeta() { var dir = "my-save-dir/"; using (var sess = tf.Session()) { - var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); + var new_saver = tf.train.import_meta_graph(@"D:\tmp\resnet_v2_101_2017_04_14\eval.graph"); new_saver.restore(sess, dir + "my-model-10000"); var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); var batch_size = tf.size(labels); diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs new file mode 100644 index 00000000..4b6d5922 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ImageTest.cs @@ -0,0 +1,30 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class ImageTest + { + string imgPath = "../../../../../data/shasta-daisy.jpg"; + Tensor contents; + + public ImageTest() + { + imgPath = Path.GetFullPath(imgPath); + contents = tf.read_file(imgPath); + } + + [TestMethod] + public void decode_image() + { + var img = tf.image.decode_image(contents); + Assert.AreEqual(img.name, "decode_image/cond_jpeg/Merge:0"); + } + } +}