| @@ -43,6 +43,7 @@ namespace Tensorflow.Contexts | |||||
| public SafeContextHandle Handle => _handle; | public SafeContextHandle Handle => _handle; | ||||
| int? _seed; | int? _seed; | ||||
| Random _rng; | |||||
| public Context() | public Context() | ||||
| { | { | ||||
| @@ -74,11 +75,23 @@ namespace Tensorflow.Contexts | |||||
| } | } | ||||
| public void set_global_seed(int? seed) | public void set_global_seed(int? seed) | ||||
| => _seed = seed; | |||||
| { | |||||
| _seed = seed; | |||||
| if (seed.HasValue) | |||||
| _rng = new Random(seed.Value); | |||||
| else | |||||
| _rng = null; | |||||
| // Also clear the kernel cache, to reset any existing seeds | |||||
| if (_handle != null) | |||||
| c_api.TFE_ContextClearCaches(_handle); | |||||
| } | |||||
| public int? global_seed() | public int? global_seed() | ||||
| => _seed; | => _seed; | ||||
| public int? internal_operation_seed() | |||||
| => _rng?.Next(0, int.MaxValue); | |||||
| public void start_step() | public void start_step() | ||||
| => c_api.TFE_ContextStartStep(_handle); | => c_api.TFE_ContextStartStep(_handle); | ||||
| @@ -94,7 +107,7 @@ namespace Tensorflow.Contexts | |||||
| { | { | ||||
| if(context_switches.Count() == 0) | if(context_switches.Count() == 0) | ||||
| tf.enable_eager_execution(); | tf.enable_eager_execution(); | ||||
| return context_switches.Current().EagerMode; | return context_switches.Current().EagerMode; | ||||
| } | } | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -21,6 +22,7 @@ namespace Tensorflow | |||||
| public class random_seed | public class random_seed | ||||
| { | { | ||||
| private static int DEFAULT_GRAPH_SEED = 87654321; | private static int DEFAULT_GRAPH_SEED = 87654321; | ||||
| private static Dictionary<string, int> _graph_to_seed_dict = new Dictionary<string, int>(); | |||||
| public static (int?, int?) get_seed(int? op_seed = null) | public static (int?, int?) get_seed(int? op_seed = null) | ||||
| { | { | ||||
| @@ -32,7 +34,20 @@ namespace Tensorflow | |||||
| global_seed = ops.get_default_graph().seed; | global_seed = ops.get_default_graph().seed; | ||||
| if (global_seed.HasValue) | if (global_seed.HasValue) | ||||
| { | |||||
| if (!op_seed.HasValue) | |||||
| if (tf.executing_eagerly()) | |||||
| op_seed = tf.Context.internal_operation_seed(); | |||||
| else | |||||
| { | |||||
| if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed)) | |||||
| seed = 0; | |||||
| _graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1; | |||||
| op_seed = seed; | |||||
| } | |||||
| return (global_seed, op_seed); | return (global_seed, op_seed); | ||||
| } | |||||
| if (op_seed.HasValue) | if (op_seed.HasValue) | ||||
| return (DEFAULT_GRAPH_SEED, op_seed); | return (DEFAULT_GRAPH_SEED, op_seed); | ||||
| @@ -131,7 +131,9 @@ namespace Tensorflow | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | ||||
| "RandomShuffle", name, | "RandomShuffle", name, | ||||
| null, | null, | ||||
| value, seed, seed2); | |||||
| value, | |||||
| "seed", seed, | |||||
| "seed2", seed2); | |||||
| return results[0]; | return results[0]; | ||||
| } | } | ||||
| @@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| /// Test the function of setting random seed | /// Test the function of setting random seed | ||||
| /// This will help regenerate the same result | /// This will help regenerate the same result | ||||
| /// </summary> | /// </summary> | ||||
| [TestMethod, Ignore] | |||||
| [TestMethod] | |||||
| public void TFRandomSeedTest() | public void TFRandomSeedTest() | ||||
| { | { | ||||
| var initValue = np.arange(6).reshape(3, 2); | var initValue = np.arange(6).reshape(3, 2); | ||||
| @@ -37,7 +37,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| /// <summary> | /// <summary> | ||||
| /// compare to Test above, seed is also added in params | /// compare to Test above, seed is also added in params | ||||
| /// </summary> | /// </summary> | ||||
| [TestMethod, Ignore] | |||||
| [TestMethod] | |||||
| public void TFRandomSeedTest2() | public void TFRandomSeedTest2() | ||||
| { | { | ||||
| var initValue = np.arange(6).reshape(3, 2); | var initValue = np.arange(6).reshape(3, 2); | ||||
| @@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| /// <summary> | /// <summary> | ||||
| /// This part we use funcs in tf.random rather than only tf | /// This part we use funcs in tf.random rather than only tf | ||||
| /// </summary> | /// </summary> | ||||
| [TestMethod, Ignore] | |||||
| [TestMethod] | |||||
| public void TFRandomRaodomSeedTest() | public void TFRandomRaodomSeedTest() | ||||
| { | { | ||||
| tf.set_random_seed(1234); | tf.set_random_seed(1234); | ||||
| @@ -83,7 +83,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| /// <summary> | /// <summary> | ||||
| /// compare to Test above, seed is also added in params | /// compare to Test above, seed is also added in params | ||||
| /// </summary> | /// </summary> | ||||
| [TestMethod, Ignore] | |||||
| [TestMethod] | |||||
| public void TFRandomRaodomSeedTest2() | public void TFRandomRaodomSeedTest2() | ||||
| { | { | ||||
| tf.set_random_seed(1234); | tf.set_random_seed(1234); | ||||