From f3a5d190872a9b4d6f96dcd7430628375885b536 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 28 Aug 2021 12:47:39 -0500 Subject: [PATCH] fix AutoNumPy crash when in multiple thread. --- src/TensorFlowNET.Core/Graphs/Graph.Import.cs | 1 - src/TensorFlowNET.Core/NumPy/AutoNumPyAttribute.cs | 9 ++++++++- src/TensorFlowNET.Core/Numpy/NDArray.cs | 1 + src/TensorFlowNET.Core/Operations/math_ops.cs | 1 - src/TensorFlowNET.Core/Tensors/constant_op.cs | 2 +- test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs | 7 ------- 6 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs index 53c37218..28ecd64e 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs @@ -52,7 +52,6 @@ namespace Tensorflow using (var status = new Status()) using (var graph_def = new Buffer(bytes)) { - as_default(); c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix); c_api.TF_GraphImportGraphDef(_handle, graph_def.Handle, opts.Handle, status.Handle); status.Check(true); diff --git a/src/TensorFlowNET.Core/NumPy/AutoNumPyAttribute.cs b/src/TensorFlowNET.Core/NumPy/AutoNumPyAttribute.cs index 5a551609..94828922 100644 --- a/src/TensorFlowNET.Core/NumPy/AutoNumPyAttribute.cs +++ b/src/TensorFlowNET.Core/NumPy/AutoNumPyAttribute.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Threading; using static Tensorflow.Binding; namespace Tensorflow.NumPy @@ -10,9 +11,12 @@ namespace Tensorflow.NumPy public sealed class AutoNumPyAttribute : OnMethodBoundaryAspect { bool _changedMode = false; - + bool _locked = false; + static object locker = new Object(); public override void OnEntry(MethodExecutionArgs args) { + Monitor.Enter(locker, ref _locked); + if (!tf.executing_eagerly()) { tf.Context.eager_mode(); @@ -24,6 +28,9 @@ namespace Tensorflow.NumPy { if (_changedMode) tf.Context.restore_mode(); + + if (_locked) + Monitor.Exit(locker); } } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index edb4d292..3a2cb3ee 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -30,6 +30,7 @@ namespace Tensorflow.NumPy [AutoNumPy] public NDArray reshape(Shape newshape) => new NDArray(tf.reshape(this, newshape)); + [AutoNumPy] public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(this, dtype)); public NDArray ravel() => throw new NotImplementedException(""); public void shuffle(NDArray nd) => np.random.shuffle(nd); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 97b9d13f..3411308c 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -126,7 +126,6 @@ namespace Tensorflow return tf_with(ops.name_scope(name, "Cast", new { x }), scope => { name = scope; - x = ops.convert_to_tensor(x, name: "x"); if (x.dtype.as_base_dtype() != base_type) x = gen_math_ops.cast(x, base_type, name: name); diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 0dccb955..2c903517 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -101,7 +101,7 @@ namespace Tensorflow value is NDArray nd && nd.dtype != dtype) { - value = nd.astype(dtype); + value = math_ops.cast(nd, dtype); } // non ascii char diff --git a/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs index 49d32261..41d8ab03 100644 --- a/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs +++ b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs @@ -197,13 +197,6 @@ namespace TensorFlowNET.UnitTest return new AndConstraint(this); } - public AndConstraint BeScalar(object value) - { - Subject.shape.IsScalar.Should().BeTrue(); - Subject.GetValue().Should().Be(value); - return new AndConstraint(this); - } - public AndConstraint BeOfType(Type typeCode) { Subject.dtype.Should().Be(typeCode);