Browse Source

1. Fix AdamOptimizer: dtype was not initialized.

pull/519/head
dogvane 5 years ago
parent
commit
07bc287e88
3 changed files with 13 additions and 6 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.train.cs
  2. +9
    -6
      src/TensorFlowNET.Core/Training/AdamOptimizer.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Training/Optimizer.cs

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.train.cs View File

@@ -41,6 +41,9 @@ namespace Tensorflow
public Optimizer AdamOptimizer(float learning_rate, string name = "Adam")
=> new AdamOptimizer(learning_rate, name: name);

public Optimizer AdamOptimizer(float learning_rate, TF_DataType dtype, string name = "Adam")
=> new AdamOptimizer(learning_rate, name: name, dtype: dtype);

public Optimizer AdamOptimizer(Tensor learning_rate, string name = "Adam")
=> new AdamOptimizer(learning_rate, name: name);



+ 9
- 6
src/TensorFlowNET.Core/Training/AdamOptimizer.cs View File

@@ -32,21 +32,24 @@ namespace Tensorflow.Train
float _beta2;
float _epsilon;
Tensor _beta1_t, _beta2_t, _epsilon_t;
TF_DataType _dtype;

public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam")
public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "Adam")
: base(learning_rate, use_locking, name)
{
_beta1 = beta1;
_beta2 = beta2;
_epsilon = epsilon;
_dtype = dtype;
}

public AdamOptimizer(Tensor learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam")
public AdamOptimizer(Tensor learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "Adam")
: base(learning_rate, use_locking, name)
{
_beta1 = beta1;
_beta2 = beta2;
_epsilon = epsilon;
_dtype = dtype;
}

public override Operation _apply_sparse(IndexedSlices grad, RefVariable var)
@@ -154,10 +157,10 @@ namespace Tensorflow.Train
var beta2 = _call_if_callable(_beta2);
var epsilon = _call_if_callable(_epsilon);

_lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate");
_beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1");
_beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2");
_epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon");
_lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate", dtype: _dtype);
_beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1", dtype: _dtype);
_beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2", dtype: _dtype);
_epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon", dtype: _dtype);
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Training/Optimizer.cs View File

@@ -253,6 +253,7 @@ namespace Tensorflow
v = variable_scope.default_variable_creator(
initial_value,
name: name,
dtype: colocate_with.dtype.as_base_dtype(),
trainable: false,
use_resource: resource_variable_ops.is_resource_variable(
colocate_with));


Loading…
Cancel
Save