diff --git a/src/TensorFlowNET.Core/Keras/BackendBase.cs b/src/TensorFlowNET.Core/Keras/BackendBase.cs index 0b553644..50263957 100644 --- a/src/TensorFlowNET.Core/Keras/BackendBase.cs +++ b/src/TensorFlowNET.Core/Keras/BackendBase.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using static Tensorflow.Python; namespace Tensorflow.Keras { @@ -43,9 +44,9 @@ namespace Tensorflow.Keras { if (value == null) value = _IMAGE_DATA_FORMAT; - if (value.GetType() == typeof(ImageDataFormat)) + if (isinstance(value, typeof(ImageDataFormat))) return (ImageDataFormat)value; - else if (value.GetType() == typeof(string)) + else if (isinstance(value, typeof(string))) { ImageDataFormat dataFormat; if(Enum.TryParse((string)value, true, out dataFormat)) diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 934ae7d2..0b41e811 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -141,7 +141,7 @@ namespace Tensorflow dtype = input_arg.Type; else if (attrs.ContainsKey(input_arg.TypeAttr)) dtype = (DataType)attrs[input_arg.TypeAttr]; - else if (values.GetType() == typeof(string) && dtype == DataType.DtInvalid) + else if (isinstance(values, typeof(string)) && dtype == DataType.DtInvalid) dtype = DataType.DtString; else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index e4585c61..0122aafb 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -309,6 +309,36 @@ namespace Tensorflow } return (__object__)((object[] args) => { return "NaN"; }); } + public static IEnumerable TupleToEnumerable(object tuple) + { + Type t = tuple.GetType(); + if(t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) + { + var flds = t.GetFields(); + for(int i = 0; i < flds.Length;i++) + { + yield return ((object)flds[i].GetValue(tuple)); + } + } + else + { + throw new System.Exception("Expected Tuple."); + } + } + public static bool isinstance(object Item1, Type Item2) + { + return (Item1.GetType() == Item2); + } + public static bool isinstance(object Item1, object tuple) + { + var tup = TupleToEnumerable(tuple); + foreach(var t in tup) + { + if(isinstance(Item1, (Type)t)) + return true; + } + return false; + } } public interface IPython : IDisposable diff --git a/test/TensorFlowNET.UnitTest/PythonBaseTests.cs b/test/TensorFlowNET.UnitTest/PythonBaseTests.cs index dff62549..436f5154 100644 --- a/test/TensorFlowNET.UnitTest/PythonBaseTests.cs +++ b/test/TensorFlowNET.UnitTest/PythonBaseTests.cs @@ -25,6 +25,31 @@ namespace TensorFlowNET.UnitTest Assert.IsTrue(b); } + [TestMethod] + public void isinstance_test() + { + var s1 = "hi"; + var s2 = "hello"; + + var t1 = (s1, s2); + var t2 = (s1, s2, s1); + var t3 = (s2, s1); + + var true1 = isinstance(s1, typeof(string)); + var false1 = isinstance(t1, typeof(string)); + var true2 = isinstance(t1, t3.GetType()); + var false2 = isinstance(t1, t2.GetType()); + var true3 = isinstance(t1, (t2.GetType(), t1.GetType(), typeof(string))); + var false3 = isinstance(t3, (t2.GetType(), typeof(string))); + + Assert.IsTrue(true1); + Assert.IsTrue(true2); + Assert.IsTrue(true3); + Assert.IsFalse(false1); + Assert.IsFalse(false2); + Assert.IsFalse(false3); + } + [TestMethod] public void hasattr_getattr() {