From 98fb777df95e887a0ee51c32a5da7436aa08f096 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 7 Jan 2019 03:16:39 -0600 Subject: [PATCH] LinearRegression --- src/TensorFlowNET.Core/APIs/tf.math.cs | 5 +++++ src/TensorFlowNET.Core/Operations/gen_math_ops.cs | 11 +++++++++++ test/TensorFlowNET.Examples/LinearRegression.cs | 6 +++++- 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 5f019bb4..24d382ec 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -11,6 +11,11 @@ namespace Tensorflow return gen_math_ops.add(a, b); } + public static unsafe Tensor add(Tensor a, RefVariable b) + { + return gen_math_ops.add(a, b); + } + public static unsafe Tensor multiply(Tensor x, Tensor y) { return gen_math_ops.mul(x, y); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 27e555bc..f1706311 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -20,6 +20,17 @@ namespace Tensorflow return new Tensor(_op, 0, _op.OutputType(0)); } + public static Tensor add(Tensor a, RefVariable b) + { + var keywords = new Dictionary(); + keywords.Add("x", a); + keywords.Add("y", b); + + var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); + + return new Tensor(_op, 0, _op.OutputType(0)); + } + public static Tensor mul(Tensor x, Tensor y) { var keywords = new Dictionary(); diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index 7ecf02d1..633d7b82 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -36,7 +36,11 @@ namespace TensorFlowNET.Examples var W = tf.Variable(rng.randn(), name: "weight"); var b = tf.Variable(rng.randn(), name: "bias"); - var aa = tf.multiply(X, W); + var part1 = tf.multiply(X, W); + var pred = tf.add(part1, b); + + // Mean squared error + var cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples); } } }