From e748a8ff0741846becbe533b3564c0b73f8d56d1 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 25 Jul 2021 10:25:42 -0500 Subject: [PATCH] overload ndarray == operator --- src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs | 3 +++ src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs | 10 ++++++++-- src/TensorFlowNET.Core/Training/Saving/Saver.cs | 5 ++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs index 5935db0f..376183f3 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs @@ -22,5 +22,8 @@ namespace Tensorflow.NumPy public static NDArray operator <(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.less(lhs, rhs)); [AutoNumPy] public static NDArray operator -(NDArray lhs) => new NDArray(gen_math_ops.neg(lhs)); + [AutoNumPy] + public static bool operator ==(NDArray lhs, NDArray rhs) => rhs is null ? false : (bool)math_ops.equal(lhs, rhs); + public static bool operator !=(NDArray lhs, NDArray rhs) => !(lhs == rhs); } } diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs index c15cd685..61141cd0 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs @@ -62,10 +62,16 @@ namespace Tensorflow.NumPy [AutoNumPy] public static NDArray load(string file) => tf.numpy.load(file); + public static T Load(string path) + where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable + { + using (var stream = new FileStream(path, FileMode.Open)) + return Load(stream); + } + [AutoNumPy] public static T Load(Stream stream) - where T : class, - ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable + where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable => tf.numpy.Load(stream); [AutoNumPy] diff --git a/src/TensorFlowNET.Core/Training/Saving/Saver.cs b/src/TensorFlowNET.Core/Training/Saving/Saver.cs index 6138dba4..85a3ee7d 100644 --- a/src/TensorFlowNET.Core/Training/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Training/Saving/Saver.cs @@ -211,7 +211,10 @@ namespace Tensorflow export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info); } - return _is_empty ? string.Empty : model_checkpoint_path[0].StringData()[0]; + return checkpoint_file; + //var x = model_checkpoint_path[0]; + //var str = x.StringData(); + //return _is_empty ? string.Empty : model_checkpoint_path[0].StringData()[0]; } public (Saver, object) import_meta_graph(string meta_graph_or_file,