Browse Source

LinearRegression

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
98fb777df9
3 changed files with 21 additions and 1 deletions
  1. +5
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  3. +5
    -1
      test/TensorFlowNET.Examples/LinearRegression.cs

+ 5
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -11,6 +11,11 @@ namespace Tensorflow
return gen_math_ops.add(a, b);
}

public static unsafe Tensor add(Tensor a, RefVariable b)
{
return gen_math_ops.add(a, b);
}

public static unsafe Tensor multiply(Tensor x, Tensor y)
{
return gen_math_ops.mul(x, y);


+ 11
- 0
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -20,6 +20,17 @@ namespace Tensorflow
return new Tensor(_op, 0, _op.OutputType(0));
}

public static Tensor add(Tensor a, RefVariable b)
{
var keywords = new Dictionary<string, object>();
keywords.Add("x", a);
keywords.Add("y", b);

var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords);

return new Tensor(_op, 0, _op.OutputType(0));
}

public static Tensor mul(Tensor x, Tensor y)
{
var keywords = new Dictionary<string, object>();


+ 5
- 1
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -36,7 +36,11 @@ namespace TensorFlowNET.Examples
var W = tf.Variable(rng.randn<double>(), name: "weight");
var b = tf.Variable(rng.randn<double>(), name: "bias");

var aa = tf.multiply(X, W);
var part1 = tf.multiply(X, W);
var pred = tf.add(part1, b);

// Mean squared error
var cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples);
}
}
}

Loading…
Cancel
Save