| @@ -43,6 +43,7 @@ namespace Tensorflow.Contexts | |||
| public SafeContextHandle Handle => _handle; | |||
| int? _seed; | |||
| Random _rng; | |||
| public Context() | |||
| { | |||
| @@ -74,11 +75,23 @@ namespace Tensorflow.Contexts | |||
| } | |||
| public void set_global_seed(int? seed) | |||
| => _seed = seed; | |||
| { | |||
| _seed = seed; | |||
| if (seed.HasValue) | |||
| _rng = new Random(seed.Value); | |||
| else | |||
| _rng = null; | |||
| // Also clear the kernel cache, to reset any existing seeds | |||
| if (_handle != null) | |||
| c_api.TFE_ContextClearCaches(_handle); | |||
| } | |||
| public int? global_seed() | |||
| => _seed; | |||
| public int? internal_operation_seed() | |||
| => _rng?.Next(0, int.MaxValue); | |||
| public void start_step() | |||
| => c_api.TFE_ContextStartStep(_handle); | |||
| @@ -94,7 +107,7 @@ namespace Tensorflow.Contexts | |||
| { | |||
| if(context_switches.Count() == 0) | |||
| tf.enable_eager_execution(); | |||
| return context_switches.Current().EagerMode; | |||
| } | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Collections.Generic; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -21,6 +22,7 @@ namespace Tensorflow | |||
| public class random_seed | |||
| { | |||
| private static int DEFAULT_GRAPH_SEED = 87654321; | |||
| private static Dictionary<string, int> _graph_to_seed_dict = new Dictionary<string, int>(); | |||
| public static (int?, int?) get_seed(int? op_seed = null) | |||
| { | |||
| @@ -32,7 +34,20 @@ namespace Tensorflow | |||
| global_seed = ops.get_default_graph().seed; | |||
| if (global_seed.HasValue) | |||
| { | |||
| if (!op_seed.HasValue) | |||
| if (tf.executing_eagerly()) | |||
| op_seed = tf.Context.internal_operation_seed(); | |||
| else | |||
| { | |||
| if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed)) | |||
| seed = 0; | |||
| _graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1; | |||
| op_seed = seed; | |||
| } | |||
| return (global_seed, op_seed); | |||
| } | |||
| if (op_seed.HasValue) | |||
| return (DEFAULT_GRAPH_SEED, op_seed); | |||
| @@ -131,7 +131,9 @@ namespace Tensorflow | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "RandomShuffle", name, | |||
| null, | |||
| value, seed, seed2); | |||
| value, | |||
| "seed", seed, | |||
| "seed2", seed2); | |||
| return results[0]; | |||
| } | |||
| @@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| /// Test the function of setting random seed | |||
| /// This will help regenerate the same result | |||
| /// </summary> | |||
| [TestMethod, Ignore] | |||
| [TestMethod] | |||
| public void TFRandomSeedTest() | |||
| { | |||
| var initValue = np.arange(6).reshape(3, 2); | |||
| @@ -37,7 +37,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| /// <summary> | |||
| /// compare to Test above, seed is also added in params | |||
| /// </summary> | |||
| [TestMethod, Ignore] | |||
| [TestMethod] | |||
| public void TFRandomSeedTest2() | |||
| { | |||
| var initValue = np.arange(6).reshape(3, 2); | |||
| @@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| /// <summary> | |||
| /// This part we use funcs in tf.random rather than only tf | |||
| /// </summary> | |||
| [TestMethod, Ignore] | |||
| [TestMethod] | |||
| public void TFRandomRaodomSeedTest() | |||
| { | |||
| tf.set_random_seed(1234); | |||
| @@ -83,7 +83,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| /// <summary> | |||
| /// compare to Test above, seed is also added in params | |||
| /// </summary> | |||
| [TestMethod, Ignore] | |||
| [TestMethod] | |||
| public void TFRandomRaodomSeedTest2() | |||
| { | |||
| tf.set_random_seed(1234); | |||