diff --git a/src/TensorFlowNET.Core/IReturnTensorOrOperation.cs b/src/TensorFlowNET.Core/IReturnTensorOrOperation.cs
new file mode 100644
index 00000000..51c840ac
--- /dev/null
+++ b/src/TensorFlowNET.Core/IReturnTensorOrOperation.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ ///
+ /// in order to limit function return value
+ /// is Tensor or Operation
+ ///
+ public interface IReturnTensorOrOperation
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index 769c7fb3..4d268c6d 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -7,7 +7,7 @@ using System.Text;
namespace Tensorflow
{
- public partial class Operation
+ public partial class Operation : IReturnTensorOrOperation
{
private readonly IntPtr _handle; // _c_op in python
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 90b7c2dc..ab7d3304 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -12,7 +12,7 @@ namespace Tensorflow
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions.
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
///
- public partial class Tensor : IDisposable
+ public partial class Tensor : IDisposable, IReturnTensorOrOperation
{
private readonly IntPtr _handle;
@@ -258,7 +258,7 @@ namespace Tensorflow
}
}
- return $"{name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}";
+ return $"tf.Tensor {name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}";
}
public void Dispose()
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs
index 85fb19f8..dc510fcc 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs
@@ -6,11 +6,21 @@ namespace Tensorflow
{
public partial class RefVariable
{
- public static Tensor operator +(RefVariable t1, int t2)
+ public static Tensor operator +(RefVariable x, int y) => op_helper("add", x, y);
+ public static Tensor operator +(RefVariable x, float y) => op_helper("add", x, y);
+ public static Tensor operator +(RefVariable x, double y) => op_helper("add", x, y);
+
+ public static Tensor operator -(RefVariable x, int y) => op_helper("sub", x, y);
+ public static Tensor operator -(RefVariable x, float y) => op_helper("sub", x, y);
+ public static Tensor operator -(RefVariable x, double y) => op_helper("sub", x, y);
+
+ private static Tensor op_helper(string default_name, RefVariable x, T y)
{
- var tensor1 = t1._AsTensor();
- var tensor2 = ops.convert_to_tensor(t2, tensor1.dtype, "y");
- return gen_math_ops.add(tensor1, tensor2);
+ var tensor1 = x.value();
+ return Python.with(new ops.name_scope("", default_name, new object[] { tensor1, y }), scope => {
+ var tensor2 = ops.convert_to_tensor(y, tensor1.dtype.as_base_dtype(), "y");
+ return gen_math_ops.add(tensor1, tensor2, scope);
+ });
}
}
}
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs
index 52d5f056..39a8d909 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs
@@ -171,6 +171,34 @@ namespace Tensorflow
return op;
}
+ ///
+ /// Assigns a new value to the variable.
+ ///
+ /// The new value for this variable.
+ /// If `True`, use locking during the assignment.
+ /// The name of the operation to be created
+ ///
+ /// if True, will return something which evaluates to the
+ /// new value of the variable; if False will return the assign op.
+ ///
+ ///
+ /// A `Tensor` that will hold the new value of this variable after
+ /// the assignment has completed.
+ ///
+ public T assign(Tensor value, bool use_locking = false, string name = "", bool read_value = true)
+ where T : IReturnTensorOrOperation
+ {
+ var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);
+ if (read_value)
+ return (T)Convert.ChangeType(assign, typeof(T));
+ return (T)Convert.ChangeType(assign.op, typeof(T));
+ }
+
+ public Tensor assign(Tensor value, bool use_locking = false, string name = "")
+ {
+ return gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);
+ }
+
public override string ToString()
{
return $"tf.Variable '{name}' shape={shape} dtype={dtype}";
diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs
index 1bb5fc3d..99e7ee20 100644
--- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs
+++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs
@@ -15,7 +15,11 @@ namespace TensorFlowNET.UnitTest
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer);
+ var inc_v1 = v1.assign(v1 + 1.0f);
+ var dec_v2 = v2.assign(v2 - 1.0f);
+ // Add an op to initialize the variables.
+ var init_op = tf.global_variables_initializer();
}
}
}