diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 97a95e95..66e1ba00 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -390,7 +390,7 @@ namespace Tensorflow => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); public Tensor pow(T1 x, T2 y, string name = "pow") - => gen_math_ops.pow(x, y, name: name); + => math_ops.pow(x, y, name: name); /// /// Divides `x / y` elementwise, rounding toward the most negative integer. diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs index 89b23c62..b3b481a1 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -53,6 +53,9 @@ namespace Tensorflow.Eager public static string GetFormattedString(TF_DataType dtype, NDArray nd) { + if (nd.size == 0) + return "[]"; + switch (dtype) { case TF_DataType.TF_STRING: diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index a92d5bda..46c3fa96 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -375,6 +375,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor); + [DllImport(TensorFlowLibName)] + public static extern void TFE_TapeVariableAccessed(IntPtr variable); + [DllImport(TensorFlowLibName)] public static extern IntPtr TFE_TapeGradient(IntPtr tape, IntPtr[] target, int target_size, diff --git a/src/TensorFlowNET.Core/Gradients/GradientActor.cs b/src/TensorFlowNET.Core/Gradients/GradientActor.cs index e6dbe92a..a6000734 100644 --- a/src/TensorFlowNET.Core/Gradients/GradientActor.cs +++ b/src/TensorFlowNET.Core/Gradients/GradientActor.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow.Eager; using static Tensorflow.Binding; 
@@ -65,7 +66,7 @@ namespace Tensorflow.Gradients _tape.watch(x as EagerTensor); } - public Tensor gradient(Tensor target, Tensor sources) + public Tensor gradient(Tensor target, Tensor source) { if(_recording) { @@ -76,15 +77,33 @@ namespace Tensorflow.Gradients using var status = new Status(); var et = c_api.TFE_TapeGradient(_tape, new [] { (target as EagerTensor).EagerTensorHandle }, 1, - new [] { (sources as EagerTensor).EagerTensorHandle }, 1, + new [] { (source as EagerTensor).EagerTensorHandle }, 1, status); status.Check(true); return new EagerTensor(et); } + public Tensor gradient(Tensor target, ResourceVariable[] sources) + { + if (_recording) + { + if (!_persistent) + _pop_tape(); + } + + using var status = new Status(); + EagerTensorHandle et = c_api.TFE_TapeGradient(_tape, + new[] { (target as EagerTensor).EagerTensorHandle }, 1, + sources.Select(x => (x.handle as EagerTensor).EagerTensorHandle).ToArray(), sources.Length, + status); + status.Check(true); + return et; + } + public void Dispose() { - + if (_recording) + _pop_tape(); } } } diff --git a/src/TensorFlowNET.Core/Gradients/Tape.cs b/src/TensorFlowNET.Core/Gradients/Tape.cs index 8bcf7f5f..00162a8f 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.cs @@ -25,6 +25,11 @@ namespace Tensorflow.Gradients c_api.TFE_TapeSetRemove(tape); } + public static void variable_accessed(ResourceVariable variable) + { + c_api.TFE_TapeVariableAccessed(variable.handle as EagerTensor); + } + public static bool IsDtypeTrainable(DataType dtype) { switch (dtype) diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 3fe85dbe..c728388c 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -220,6 +220,18 @@ namespace Tensorflow /// public static Tensor identity(Tensor input, string name = null) { + if (tf.context.executing_eagerly()) + { + 
using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Identity", name, new IntPtr[] + { + input as EagerTensor + }, 1, null, status); + status.Check(true); + return tensor; + } + var _op = _op_def_lib._apply_op_helper("Identity", name, new { input }); return _op.output; @@ -258,14 +270,14 @@ namespace Tensorflow if (tf.context.executing_eagerly()) { using var status = new Status(); - var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Fill", name, new IntPtr[] { dims as EagerTensor, value as EagerTensor }, 2, null, status); status.Check(true); - return new EagerTensor(tensor); + return tensor; } var _op = _op_def_lib._apply_op_helper("Fill", name, new { dims, value }); @@ -281,6 +293,18 @@ namespace Tensorflow /// A tuple of `Tensor` objects (r0, r1). public static (Tensor, Tensor) broadcast_gradient_args(Tensor s0, Tensor s1, string name = "") { + if (tf.context.executing_eagerly()) + { + using var status = new Status(); + var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "BroadcastGradientArgs", name, new IntPtr[] + { + s0 as EagerTensor, + s1 as EagerTensor + }, 2, null, status); + status.Check(true); + } + var _op = _op_def_lib._apply_op_helper("BroadcastGradientArgs", name, new { s0, s1 }); return (_op.outputs[0], _op.outputs[1]); @@ -371,10 +395,19 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Shape", name, null, - input, "out_type", out_type); - return _result; + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Shape", name, new IntPtr[] + { + input as EagerTensor, + }, 1, + op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] + { + "out_type", out_type + }, 
status), + status); + status.Check(true); + return tensor; } var _op = _op_def_lib._apply_op_helper("Shape", name, new { input, out_type }); @@ -455,12 +488,26 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "StridedSlice", name, null, - input, begin, end, strides, "begin_mask", begin_mask, - "end_mask", end_mask, "ellipsis_mask", ellipsis_mask, - "new_axis_mask", new_axis_mask, "shrink_axis_mask", shrink_axis_mask); - return _result; + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "StridedSlice", name, new IntPtr[] + { + input as EagerTensor, + begin as EagerTensor, + end as EagerTensor, + strides as EagerTensor, + }, 4, + op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] + { + "begin_mask", begin_mask, + "end_mask", end_mask, + "ellipsis_mask", ellipsis_mask, + "new_axis_mask", new_axis_mask, + "shrink_axis_mask", shrink_axis_mask + }, status), + status); + status.Check(true); + return tensor; } var _op = _op_def_lib._apply_op_helper("StridedSlice", name, new diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 5f882acf..2afa02c6 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -173,10 +173,20 @@ namespace Tensorflow { try { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Prod", name, null, - input, axis, "keep_dims", keep_dims); - return _result; + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Prod", name, new IntPtr[] + { + input as EagerTensor, + axis as EagerTensor + }, 2, + op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] + { + "keep_dims", keep_dims + }, status), + status); + status.Check(true); + 
return tensor; } catch (Exception) { @@ -236,10 +246,15 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Add", name, null, - x, y); - return _result; + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Add", name, new IntPtr[] + { + x as EagerTensor, + y as EagerTensor + }, 2, null, status); + status.Check(true); + return tensor; } var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); @@ -647,10 +662,14 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Sqrt", name, null, - x); - return _result; + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Sqrt", name, new IntPtr[] + { + x as EagerTensor, + }, 1, null, status); + status.Check(true); + return tensor; } var _op = _op_def_lib._apply_op_helper("Sqrt", name, args: new { x }); @@ -682,10 +701,15 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Sub", name, null, - x, y); - return _result; + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Sub", name, new IntPtr[] + { + x as EagerTensor, + y as EagerTensor + }, 2, null, status); + status.Check(true); + return tensor; } var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); @@ -704,10 +728,15 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Equal", name, null, - x, y); - return _result; + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, 
tf.context.device_name, + "Equal", name, new IntPtr[] + { + x as EagerTensor, + y as EagerTensor + }, 2, null, status); + status.Check(true); + return tensor; } var _op = _op_def_lib._apply_op_helper("Equal", name, args: new { x, y }); @@ -727,10 +756,15 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "NotEqual", name, null, - x, y); - return _result; + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "NotEqual", name, new IntPtr[] + { + x as EagerTensor, + y as EagerTensor + }, 2, null, status); + status.Check(true); + return tensor; } var _op = _op_def_lib._apply_op_helper("NotEqual", name, args: new { x, y }); @@ -742,10 +776,15 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Atan2", name, null, - y, x); - return _result; + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Atan2", name, new IntPtr[] + { + y as EagerTensor, + x as EagerTensor + }, 2, null, status); + status.Check(true); + return tensor; } var _op = _op_def_lib._apply_op_helper("Atan2", name, args: new { y, x }); @@ -757,14 +796,14 @@ namespace Tensorflow if (tf.context.executing_eagerly()) { using var status = new Status(); - var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Mul", name, new IntPtr[] { x as EagerTensor, y as EagerTensor }, 2, null, status); status.Check(true); - return new EagerTensor(_result); + return tensor; } var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); @@ -776,10 +815,15 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = 
wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Mul", name, null, - x, y); - return _result; + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Mul", name, new IntPtr[] + { + x as EagerTensor, + y as EagerTensor, + }, 2, null, status); + status.Check(true); + return tensor; } var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); @@ -832,8 +876,15 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "FloorDiv", name, null, x, y); - return _result; + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "FloorDiv", name, new IntPtr[] + { + x as EagerTensor, + y as EagerTensor + }, 2, null, status); + status.Check(true); + return tensor; } var _op = _op_def_lib._apply_op_helper("FloorDiv", name, args: new { x, y }); @@ -864,10 +915,8 @@ namespace Tensorflow }, 2, op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { - "transpose_a", - transpose_a, - "transpose_b", - transpose_b + "transpose_a", transpose_a, + "transpose_b", transpose_b }, status), status); status.Check(true); @@ -965,6 +1014,19 @@ namespace Tensorflow public static Tensor pow(Tx x, Ty y, string name = null) { + if (tf.context.executing_eagerly()) + { + using var status = new Status(); + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Pow", name, new IntPtr[] + { + x as EagerTensor, + y as EagerTensor + }, 2, null, status); + status.Check(true); + return tensor; + } + var _op = _op_def_lib._apply_op_helper("Pow", name, args: new { x, y }); return _op.outputs[0]; diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs index 8ce319eb..f91177d1 100644 --- 
a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -115,13 +115,13 @@ namespace Tensorflow if (tf.context.executing_eagerly()) { using var status = new Status(); - var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "ReadVariableOp", name, new IntPtr[] { resource as EagerTensor }, 1, op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "dtype", dtype }, status), status); status.Check(true); - return new EagerTensor(tensor); + return tensor; } var _op = _op_def_lib._apply_op_helper("ReadVariableOp", name, new diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index ce089032..1e4aefce 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -17,6 +17,7 @@ using NumSharp; using System; using System.Collections.Generic; +using Tensorflow.Eager; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -540,6 +541,11 @@ namespace Tensorflow } else { + if(x is EagerTensor) + { + return constant_op.constant(np.arange(x.shape.Rank)); + } + var rank = array_ops.rank(x); return range(0, rank, 1); } @@ -588,7 +594,14 @@ namespace Tensorflow => gen_math_ops.rsqrt(x, name: name); public static Tensor pow(Tx x, Ty y, string name = null) - => gen_math_ops.pow(x, y, name: name); + => tf_with(ops.name_scope(name, "Pow", new { x, y }), scope => + { + name = scope; + var x_tensor = ops.convert_to_tensor(x, name: "x"); + var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype()); + + return gen_math_ops.pow(x_tensor, y_tensor, name: name); + }); public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") { diff --git 
a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index e9ecb79a..fc97895d 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -54,7 +54,7 @@ namespace Tensorflow #else #region Compute - + public static Tensor operator +(Tensor lhs, ResourceVariable rhs) => BinaryOpWrapper("add", lhs, rhs); public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs); public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index 84ba7c04..1a02e2c5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -43,7 +43,7 @@ namespace Tensorflow { //T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called. 
var src = (T*)buffer; - len *= ((long)itemsize); + len *= (long)itemsize; System.Buffer.MemoryCopy(src, dst, len, len); } } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 73e3365a..3882646c 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -113,6 +113,21 @@ namespace Tensorflow private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid) { + // convert data type + if (dtype != TF_DataType.DtInvalid && + !(value is NDArray) && + dtypes.as_base_dtype(dtype) != dtypes.as_dtype(value.GetType())) + { + switch (dtype) + { + case TF_DataType.TF_FLOAT: + value = Convert.ToSingle(value); + break; + default: + break; + } + } + switch (value) { case NDArray val: diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 6cc8972f..1c0307a2 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Gradients; using static Tensorflow.Binding; namespace Tensorflow @@ -65,6 +66,7 @@ namespace Tensorflow protected Tensor _read_variable_op() { + variable_accessed(this); var result = gen_resource_variable_ops.read_variable_op(_handle, _dtype); // _maybe_set_handle_data(_dtype, _handle, result); return result; @@ -82,12 +84,26 @@ namespace Tensorflow void variable_accessed(BaseResourceVariable variable) { if (variable.trainable) - ; // tape.variable_accessed(variable) + Tape.variable_accessed(variable as ResourceVariable); } + /// + /// Constructs an op which reads the value of this variable. + /// + /// Should be used when there are multiple reads, or when it is desirable to + /// read the value only after some condition is true. 
+ /// + /// + Tensor read_value() + => tf_with(ops.name_scope("Read"), delegate + { + var value = _read_variable_op(); + return array_ops.identity(value); + }); + public override string ToString() => $"tf.Variable '{name}' shape={shape} dtype={dtype.as_numpy_name()}, numpy={numpy()}"; - public NDArray numpy() => _read_variable_op().numpy(); + public NDArray numpy() => read_value().numpy(); } } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs index d6eafbc1..80aab711 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using NumSharp; using System; using static Tensorflow.Binding; @@ -31,6 +32,7 @@ namespace Tensorflow public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); public static Tensor operator *(ResourceVariable x, ResourceVariable y) => gen_math_ops.mul(x, y); + public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y); public static Tensor operator <(ResourceVariable x, Tensor y) => gen_math_ops.less(x.value(), y); @@ -53,6 +55,9 @@ namespace Tensorflow case "sub": result = gen_math_ops.sub(xVal, yTensor, name); break; + case "mul": + result = gen_math_ops.mul(xVal, yTensor, name: name); + break; default: throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index f73278c4..644a8dc5 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -464,7 +464,7 @@ namespace Tensorflow case RefVariable varVal: return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); case ResourceVariable varVal: - return null; + return varVal.value(); case TensorShape ts: return 
constant_op.constant(ts.dims, dtype: dtype, name: name); case int[] dims: