diff --git a/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs b/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs
index e586846b..9765743a 100644
--- a/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs
+++ b/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs
@@ -24,6 +24,6 @@ namespace Tensorflow
/// TF_Output*
[DllImport(TensorFlowLibName)]
public static extern void TF_AddGradientsWithPrefix(IntPtr g, string prefix, TF_Output[] y, int ny,
-            TF_Output[] x, int nx, TF_Output[] dx, IntPtr status, ref IntPtr dy);
+            TF_Output[] x, int nx, TF_Output[] dx, IntPtr status, [In, Out] TF_Output[] dy); // dy: caller-allocated TF_Output[ny] the C API fills in (TF_Output* in c_api.h)
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
index 2287e1eb..456158e5 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
@@ -37,6 +37,8 @@ namespace Tensorflow
// ancestor graphs. This is necessary for correctly handling captured values.
var curr_graph = src_graph;
+ if (stop_gradients == null)
+ stop_gradients = new Tensor[0];
if (grad_ys == null)
grad_ys = new Tensor[ys.Length];
@@ -45,7 +47,7 @@ namespace Tensorflow
all.AddRange(xs);
all.AddRange(stop_gradients);
all.AddRange(grad_ys);
-
+
// Iterate over the collected ops.
/**
* grads: op => list of gradients received on each output endpoint of the
diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.py.cs b/src/TensorFlowNET.Core/Gradients/math_grad.py.cs
index c2758447..d6b78be6 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.py.cs
@@ -16,5 +16,17 @@ namespace Tensorflow
return (grad, grad);
}
+
+        // Gradient for RealDiv. NOTE(review): placeholder — returns (grad, grad).
+        public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad)
+        {
+            var x = op.inputs[0];
+            var y = op.inputs[1];
+            var sx = array_ops.shape(x);
+            var sy = array_ops.shape(y);
+            // TODO: correct gradients are (grad / y, -grad * x / y^2), each reduced
+            // over broadcast dims via gen_array_ops.broadcast_gradient_args(sx, sy) — stub.
+            return (grad, grad);
+        }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 8c28c58b..9acc7fe8 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -59,5 +59,10 @@ namespace Tensorflow
return _op.outputs[0];
}
+        // TODO(stub): should run the BroadcastGradientArgs op; returns (null, null) for now.
+        public static (Tensor, Tensor) broadcast_gradient_args(Tensor s0, Tensor s1, string name = "")
+        {
+            return (null, null);
+        }
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 4fb83fd0..65f3742d 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -42,7 +42,7 @@ namespace Tensorflow
{
get
{
- var dims = new long[rank];
+ var dims = new long[rank < 0 ? 0 : rank];
if (_handle == IntPtr.Zero)
{
diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs
index 244eda50..3533cd88 100644
--- a/src/TensorFlowNET.Core/Train/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Train/Optimizer.cs
@@ -14,6 +14,11 @@ namespace Tensorflow
///
public abstract class Optimizer
{
+        // Values for gate_gradients (mirrors tf.train.Optimizer); const so they cannot be reassigned.
+        public const int GATE_NONE = 0;
+        public const int GATE_OP = 1;
+        public const int GATE_GRAPH = 2;
+
public string Name { get; set; }
public double LearningRate { get; set; }
public Tensor LearningRateTensor { get; set; }
@@ -87,11 +92,15 @@ namespace Tensorflow
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
var var_refs = processors.Select(x => x.target()).ToArray();
- gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss,
+ var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss,
gate_gradients: (gate_gradients == GateGradientType.GATE_OP),
aggregation_method: aggregation_method,
colocate_gradients_with_ops: colocate_gradients_with_ops);
+ //if ((int)gate_gradients == Optimizer.GATE_GRAPH)
+ //grads = control_flow_ops.tuple(grads);
+            // TODO(review): `grads` is computed above but dropped; return the
+            // (gradient, variable) pairs instead of null once callers consume them.
return null;
}
}
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index f22f550b..84c7deff 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -8,6 +8,7 @@ using node_def_pb2 = Tensorflow;
using Google.Protobuf;
using System.Linq;
using NumSharp.Core;
+using System.ComponentModel;
namespace Tensorflow
{
@@ -285,7 +286,20 @@ namespace Tensorflow
return (oper, out_grads) =>
{
- return math_grad._AddGrad(op, out_grads);
+                switch (oper.type)
+                {
+                    // Dispatch and invocation must use the same operation: the lambda
+                    // previously switched on `oper` but called the grad fn with the
+                    // captured outer `op`, which diverges if the delegate is reused.
+                    case "Add":
+                        return math_grad._AddGrad(oper, out_grads);
+                    case "RealDiv":
+                        return math_grad._RealDivGrad(oper, out_grads);
+                    default:
+                        throw new NotImplementedException($"get_gradient_function for {oper.type}");
+                }
+ /*var result = typeof(math_grad).GetMethod($"_{op.type}Grad").Invoke(null, new object[] { op, out_grads });
+ var p1 = result.GetType().GetProperty("Item1");
+ var p2 = result.GetType().GetProperty("Item2");
+
+ return (p1.GetValue(result, null) as Tensor, p2.GetValue(result, null) as Tensor);*/
};
}
}
diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs
index 6e9a2c0a..91294c07 100644
--- a/test/TensorFlowNET.Examples/LinearRegression.cs
+++ b/test/TensorFlowNET.Examples/LinearRegression.cs
@@ -36,8 +36,8 @@ namespace TensorFlowNET.Examples
var W = tf.Variable(rng.randn(), name: "weight");
var b = tf.Variable(rng.randn(), name: "bias");
- var part1 = tf.multiply(X, W);
- var pred = tf.add(part1, b);
+ var mul = tf.multiply(X, W);
+ var pred = tf.add(mul, b);
// Mean squared error
var sub = pred - Y;
diff --git a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs
index ccdf78f6..c5a9095b 100644
--- a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs
+++ b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs
@@ -111,12 +111,11 @@ namespace TensorFlowNET.UnitTest
var grad_inputs_op = FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
grad_inputs[0] = new TF_Output(grad_inputs_op, 0);
-            IntPtr handle = IntPtr.Zero;
+            var handles = new TF_Output[2];
             c_api.TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
-                ninputs, grad_inputs, s_, ref handle);
-
-            grad_outputs[0] = Marshal.PtrToStructure(handle);
-            var op = new Operation(handle);
+                ninputs, grad_inputs, s_, handles);
+            grad_outputs[0] = handles[0]; // restore result capture dropped by this patch
+            var op = new Operation(handles[0].oper);
}
else
{