diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index c3e01278..38fb267c 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -182,12 +182,7 @@ namespace Tensorflow => nn_impl.sigmoid_cross_entropy_with_logits(labels: labels, logits: logits, name: name); public Tensor softmax(Tensor logits, int axis = -1, string name = null) - { - if (axis == -1) - return gen_nn_ops.softmax(logits, name); - else - throw new NotImplementedException(""); - } + => gen_nn_ops.softmax(logits, name); /// diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index df10e79e..3a9103ce 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -44,6 +44,9 @@ namespace Tensorflow public Optimizer AdamOptimizer(float learning_rate, TF_DataType dtype, string name = "Adam") => new AdamOptimizer(learning_rate, name: name, dtype: dtype); + public Optimizer AdamOptimizer(IVariableV1 learning_rate, string name = "Adam") + => new AdamOptimizer(learning_rate.AsTensor(), name: name); + public Optimizer AdamOptimizer(Tensor learning_rate, string name = "Adam") => new AdamOptimizer(learning_rate, name: name); diff --git a/src/TensorFlowNET.Core/Training/Optimizer.cs b/src/TensorFlowNET.Core/Training/Optimizer.cs index c9c1673c..d22562d5 100644 --- a/src/TensorFlowNET.Core/Training/Optimizer.cs +++ b/src/TensorFlowNET.Core/Training/Optimizer.cs @@ -207,7 +207,7 @@ namespace Tensorflow { apply_updates = state_ops.assign_add(global_step, ops.convert_to_tensor(1, dtype: global_step.dtype), - name: name) as Operation; + name: name); } }); } diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 18b93ec3..5fe0043e 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -73,7 +73,7 @@ namespace Tensorflow // handle_deleter } - public ITensorOrOperation assign(T value, bool use_locking = false, string name = null, bool read_value = true) + public Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true) { if(value.GetType() == typeof(Tensor)) { @@ -134,7 +134,7 @@ namespace Tensorflow return array_ops.identity(value); }); - public ITensorOrOperation assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true) + public Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true) { var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle, ops.convert_to_tensor(delta, dtype: dtype), name: name); diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index cd76b092..52549ecc 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -47,8 +47,8 @@ namespace Tensorflow public Graph Graph { get; } public TF_DataType dtype { get; } public TensorShape shape { get; } - ITensorOrOperation assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); - ITensorOrOperation assign(T value, bool use_locking = false, string name = null, bool read_value = true); + Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); + Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true); Tensor AsTensor(); } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 34f1d93f..cf9fe2f1 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -335,7 +335,7 @@ namespace Tensorflow /// A `Tensor` that will hold the new value of this variable after /// the assignment has completed. /// - public ITensorOrOperation assign(T value, bool use_locking = false, string name = null, bool read_value = true) + public Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true) { var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); if (read_value) @@ -418,7 +418,7 @@ namespace Tensorflow // name: A name for the operation(optional). // Returns: // A mutable `Tensor`. Has the same type as `ref`. - public ITensorOrOperation assign_add(T value, bool use_locking = false, string name = null, bool read_value = true) + public Tensor assign_add(T value, bool use_locking = false, string name = null, bool read_value = true) { var variable = this; var _op = tf.OpDefLib._apply_op_helper("AssignAdd", name: name, args: new { variable, value, use_locking }); diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index e7962ac1..3152686b 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -95,7 +95,7 @@ namespace Tensorflow // Returns: // Same as "ref". Returned as a convenience for operations that want // to use the new value after the variable has been updated. - public static ITensorOrOperation assign_add(IVariableV1 @ref, + public static Tensor assign_add(IVariableV1 @ref, T value, bool use_locking = false, string name = null)