| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using static Tensorflow.Python; | |||||
| namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
| { | { | ||||
| @@ -43,9 +44,9 @@ namespace Tensorflow.Keras | |||||
| { | { | ||||
| if (value == null) | if (value == null) | ||||
| value = _IMAGE_DATA_FORMAT; | value = _IMAGE_DATA_FORMAT; | ||||
| if (value.GetType() == typeof(ImageDataFormat)) | |||||
| if (isinstance(value, typeof(ImageDataFormat))) | |||||
| return (ImageDataFormat)value; | return (ImageDataFormat)value; | ||||
| else if (value.GetType() == typeof(string)) | |||||
| else if (isinstance(value, typeof(string))) | |||||
| { | { | ||||
| ImageDataFormat dataFormat; | ImageDataFormat dataFormat; | ||||
| if(Enum.TryParse((string)value, true, out dataFormat)) | if(Enum.TryParse((string)value, true, out dataFormat)) | ||||
| @@ -141,7 +141,7 @@ namespace Tensorflow | |||||
| dtype = input_arg.Type; | dtype = input_arg.Type; | ||||
| else if (attrs.ContainsKey(input_arg.TypeAttr)) | else if (attrs.ContainsKey(input_arg.TypeAttr)) | ||||
| dtype = (DataType)attrs[input_arg.TypeAttr]; | dtype = (DataType)attrs[input_arg.TypeAttr]; | ||||
| else if (values.GetType() == typeof(string) && dtype == DataType.DtInvalid) | |||||
| else if (isinstance(values, typeof(string)) && dtype == DataType.DtInvalid) | |||||
| dtype = DataType.DtString; | dtype = DataType.DtString; | ||||
| else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) | else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) | ||||
| default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; | default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; | ||||
| @@ -309,6 +309,36 @@ namespace Tensorflow | |||||
| } | } | ||||
| return (__object__)((object[] args) => { return "NaN"; }); | return (__object__)((object[] args) => { return "NaN"; }); | ||||
| } | } | ||||
| public static IEnumerable TupleToEnumerable(object tuple) | |||||
| { | |||||
| Type t = tuple.GetType(); | |||||
| if(t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) | |||||
| { | |||||
| var flds = t.GetFields(); | |||||
| for(int i = 0; i < flds.Length;i++) | |||||
| { | |||||
| yield return ((object)flds[i].GetValue(tuple)); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new System.Exception("Expected Tuple."); | |||||
| } | |||||
| } | |||||
| public static bool isinstance(object Item1, Type Item2) | |||||
| { | |||||
| return (Item1.GetType() == Item2); | |||||
| } | |||||
| public static bool isinstance(object Item1, object tuple) | |||||
| { | |||||
| var tup = TupleToEnumerable(tuple); | |||||
| foreach(var t in tup) | |||||
| { | |||||
| if(isinstance(Item1, (Type)t)) | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| public interface IPython : IDisposable | public interface IPython : IDisposable | ||||
| @@ -25,6 +25,31 @@ namespace TensorFlowNET.UnitTest | |||||
| Assert.IsTrue(b); | Assert.IsTrue(b); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void isinstance_test() | |||||
| { | |||||
| var s1 = "hi"; | |||||
| var s2 = "hello"; | |||||
| var t1 = (s1, s2); | |||||
| var t2 = (s1, s2, s1); | |||||
| var t3 = (s2, s1); | |||||
| var true1 = isinstance(s1, typeof(string)); | |||||
| var false1 = isinstance(t1, typeof(string)); | |||||
| var true2 = isinstance(t1, t3.GetType()); | |||||
| var false2 = isinstance(t1, t2.GetType()); | |||||
| var true3 = isinstance(t1, (t2.GetType(), t1.GetType(), typeof(string))); | |||||
| var false3 = isinstance(t3, (t2.GetType(), typeof(string))); | |||||
| Assert.IsTrue(true1); | |||||
| Assert.IsTrue(true2); | |||||
| Assert.IsTrue(true3); | |||||
| Assert.IsFalse(false1); | |||||
| Assert.IsFalse(false2); | |||||
| Assert.IsFalse(false3); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void hasattr_getattr() | public void hasattr_getattr() | ||||
| { | { | ||||