From 07bc287e8821a4b264b7b69979d6c3a2bd439458 Mon Sep 17 00:00:00 2001
From: dogvane
Date: Wed, 4 Mar 2020 22:05:25 +0800
Subject: [PATCH] Fix AdamOptimizer not initializing its dtype

---
 src/TensorFlowNET.Core/APIs/tf.train.cs          |  3 +++
 src/TensorFlowNET.Core/Training/AdamOptimizer.cs | 15 +++++++++------
 src/TensorFlowNET.Core/Training/Optimizer.cs     |  1 +
 3 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs
index b9bc430d..3d325e8c 100644
--- a/src/TensorFlowNET.Core/APIs/tf.train.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.train.cs
@@ -41,6 +41,9 @@ namespace Tensorflow
         public Optimizer AdamOptimizer(float learning_rate, string name = "Adam")
             => new AdamOptimizer(learning_rate, name: name);
 
+        public Optimizer AdamOptimizer(float learning_rate, TF_DataType dtype, string name = "Adam")
+            => new AdamOptimizer(learning_rate, name: name, dtype: dtype);
+
         public Optimizer AdamOptimizer(Tensor learning_rate, string name = "Adam")
             => new AdamOptimizer(learning_rate, name: name);
 
diff --git a/src/TensorFlowNET.Core/Training/AdamOptimizer.cs b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs
index 39228691..54c83cfb 100644
--- a/src/TensorFlowNET.Core/Training/AdamOptimizer.cs
+++ b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs
@@ -32,21 +32,24 @@ namespace Tensorflow.Train
         float _beta1;
         float _beta2;
         float _epsilon;
         Tensor _beta1_t, _beta2_t, _epsilon_t;
+        TF_DataType _dtype;
 
-        public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam")
+        public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "Adam")
             : base(learning_rate, use_locking, name)
         {
             _beta1 = beta1;
             _beta2 = beta2;
             _epsilon = epsilon;
+            _dtype = dtype;
         }
 
-        public AdamOptimizer(Tensor learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam")
+        public AdamOptimizer(Tensor learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "Adam")
             : base(learning_rate, use_locking, name)
         {
             _beta1 = beta1;
             _beta2 = beta2;
             _epsilon = epsilon;
+            _dtype = dtype;
         }
 
         public override Operation _apply_sparse(IndexedSlices grad, RefVariable var)
@@ -154,10 +157,10 @@ namespace Tensorflow.Train
             var beta2 = _call_if_callable(_beta2);
             var epsilon = _call_if_callable(_epsilon);
 
-            _lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate");
-            _beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1");
-            _beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2");
-            _epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon");
+            _lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate", dtype: _dtype);
+            _beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1", dtype: _dtype);
+            _beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2", dtype: _dtype);
+            _epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon", dtype: _dtype);
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Training/Optimizer.cs b/src/TensorFlowNET.Core/Training/Optimizer.cs
index 04ec949c..5272da3b 100644
--- a/src/TensorFlowNET.Core/Training/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Training/Optimizer.cs
@@ -253,6 +253,7 @@ namespace Tensorflow
                 v = variable_scope.default_variable_creator(
                     initial_value,
                     name: name,
+                    dtype: colocate_with.dtype.as_base_dtype(),
                     trainable: false,
                     use_resource: resource_variable_ops.is_resource_variable(
                         colocate_with));
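
Usage note: before this change, the Adam constants (learning rate, beta1, beta2, epsilon) were always converted with the default TF_FLOAT dtype, so they clashed with graphs built in double precision. The new `dtype` parameter (defaulting to TF_DataType.TF_FLOAT, which leaves existing callers unchanged) lets those internal tensors match the graph. A minimal sketch of the new overload; the surrounding setup is illustrative and assumes a float64 scalar tensor `loss` built elsewhere with the usual graph API:

    using Tensorflow;
    using static Tensorflow.Binding;

    // Illustrative only: `loss` stands for a float64 scalar loss tensor
    // defined by the caller. The new overload forwards dtype to the
    // AdamOptimizer constructor, so _lr_t, _beta1_t, _beta2_t and
    // _epsilon_t are created as float64 and no longer conflict with the
    // rest of the graph.
    var optimizer = tf.train.AdamOptimizer(0.001f, TF_DataType.TF_DOUBLE);
    var train_op = optimizer.minimize(loss);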