| @@ -93,7 +93,12 @@ namespace Tensorflow | |||||
| => random_ops.random_shuffle(value, seed: seed, name: name); | => random_ops.random_shuffle(value, seed: seed, name: name); | ||||
| public void set_random_seed(int seed) | public void set_random_seed(int seed) | ||||
| => ops.get_default_graph().seed = seed; | |||||
| { | |||||
| if (executing_eagerly()) | |||||
| Context.set_global_seed(seed); | |||||
| else | |||||
| ops.get_default_graph().seed = seed; | |||||
| } | |||||
| public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, | public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, | ||||
| string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) | string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) | ||||
| @@ -42,6 +42,8 @@ namespace Tensorflow.Contexts | |||||
| SafeContextHandle _handle; | SafeContextHandle _handle; | ||||
| public SafeContextHandle Handle => _handle; | public SafeContextHandle Handle => _handle; | ||||
| int? _seed; | |||||
| public Context() | public Context() | ||||
| { | { | ||||
| _device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT; | _device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT; | ||||
| @@ -71,6 +73,12 @@ namespace Tensorflow.Contexts | |||||
| initialized = true; | initialized = true; | ||||
| } | } | ||||
| public void set_global_seed(int? seed) | |||||
| => _seed = seed; | |||||
| public int? global_seed() | |||||
| => _seed; | |||||
| public void start_step() | public void start_step() | ||||
| => c_api.TFE_ContextStartStep(_handle); | => c_api.TFE_ContextStartStep(_handle); | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class random_seed | public class random_seed | ||||
| @@ -22,8 +24,18 @@ namespace Tensorflow | |||||
| public static (int?, int?) get_seed(int? op_seed = null) | public static (int?, int?) get_seed(int? op_seed = null) | ||||
| { | { | ||||
| int? global_seed; | |||||
| if (tf.executing_eagerly()) | |||||
| global_seed = tf.Context.global_seed(); | |||||
| else | |||||
| global_seed = ops.get_default_graph().seed; | |||||
| if (global_seed.HasValue) | |||||
| return (global_seed, op_seed); | |||||
| if (op_seed.HasValue) | if (op_seed.HasValue) | ||||
| return (DEFAULT_GRAPH_SEED, 0); | |||||
| return (DEFAULT_GRAPH_SEED, op_seed); | |||||
| else | else | ||||
| return (null, null); | return (null, null); | ||||
| } | } | ||||