Browse Source

Implicit between RefVariable and Tensor

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
2418539568
7 changed files with 63 additions and 70 deletions
  1. +8
    -8
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +19
    -42
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  3. +18
    -18
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  5. +5
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  6. +10
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  7. +2
    -1
      test/TensorFlowNET.Examples/LinearRegression.cs

+ 8
- 8
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -6,29 +6,29 @@ namespace Tensorflow
{
public static partial class tf
{
public static unsafe Tensor add(Tensor a, Tensor b)
public static Tensor add(Tensor a, Tensor b)
{
return gen_math_ops.add(a, b);
}

public static unsafe Tensor sub(Tensor a, Tensor b)
public static Tensor sub(Tensor a, Tensor b)
{
return gen_math_ops.sub(a, b);
}

public static unsafe Tensor add(Tensor a, RefVariable b)
public static Tensor multiply(Tensor x, Tensor y)
{
return gen_math_ops.add(a, b);
return gen_math_ops.mul(x, y);
}

public static unsafe Tensor multiply(Tensor x, Tensor y)
public static Tensor pow(Tensor x, Tensor y)
{
return gen_math_ops.mul(x, y);
return gen_math_ops.pow(x, y);
}

public static unsafe Tensor multiply(Tensor x, RefVariable y)
public static Tensor reduce_sum(Tensor input, int? axis = null)
{
return gen_math_ops.mul(x, y);
return gen_math_ops.sum(input, input);
}
}
}

+ 19
- 42
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -36,50 +36,27 @@ namespace Tensorflow
foreach (var input_arg in op_def.InputArg)
{
var input_name = input_arg.Name;
switch (keywords[input_name])
if (keywords[input_name] is Tensor value)
{
case Tensor value:
if (keywords.ContainsKey(input_name))
{
inputs.Add(value);
}

if (!String.IsNullOrEmpty(input_arg.TypeAttr))
{
attrs[input_arg.TypeAttr] = value.dtype;
}

if (input_arg.IsRef)
{

}
else
{
input_types.Add(value.dtype);
}
break;
case RefVariable value:
if (keywords.ContainsKey(input_name))
{
inputs.Add(value._initial_value);
}

if (!String.IsNullOrEmpty(input_arg.TypeAttr))
{
attrs[input_arg.TypeAttr] = value._initial_value.dtype;
}

if (input_arg.IsRef)
{

}
else
{
input_types.Add(value._initial_value.dtype);
}
break;
if (keywords.ContainsKey(input_name))
{
inputs.Add(value);
}

if (!String.IsNullOrEmpty(input_arg.TypeAttr))
{
attrs[input_arg.TypeAttr] = value.dtype;
}

if (input_arg.IsRef)
{

}
else
{
input_types.Add(value.dtype);
}
}

}

// Process remaining attrs


+ 18
- 18
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -20,59 +20,59 @@ namespace Tensorflow
return new Tensor(_op, 0, _op.OutputType(0));
}

public static Tensor add(Tensor a, RefVariable b)
public static Tensor sub(Tensor x, Tensor y)
{
var keywords = new Dictionary<string, object>();
keywords.Add("x", a);
keywords.Add("y", b);
keywords.Add("x", x);
keywords.Add("y", y);

var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords);
var _op = _op_def_lib._apply_op_helper("Sub", name: "sub", keywords: keywords);

return new Tensor(_op, 0, _op.OutputType(0));
}

public static Tensor sub(Tensor x, Tensor y)
public static Tensor mul(Tensor x, Tensor y)
{
var keywords = new Dictionary<string, object>();
keywords.Add("x", x);
keywords.Add("y", y);

var _op = _op_def_lib._apply_op_helper("Sub", name: "sub", keywords: keywords);
var _op = _op_def_lib._apply_op_helper("Mul", name: "mul", keywords: keywords);

return new Tensor(_op, 0, _op.OutputType(0));
}

public static Tensor mul(Tensor x, Tensor y)
public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false)
{
var keywords = new Dictionary<string, object>();
keywords.Add("x", x);
keywords.Add("y", y);
keywords.Add("a", a);
keywords.Add("b", b);
keywords.Add("transpose_a", transpose_a);
keywords.Add("transpose_b", transpose_b);

var _op = _op_def_lib._apply_op_helper("Mul", name: "mul", keywords: keywords);
var _op = _op_def_lib._apply_op_helper("MatMul", name: "MatMul", keywords: keywords);

return new Tensor(_op, 0, _op.OutputType(0));
}

public static Tensor mul(Tensor x, RefVariable y)
public static Tensor pow(Tensor x, Tensor y)
{
var keywords = new Dictionary<string, object>();
keywords.Add("x", x);
keywords.Add("y", y);

var _op = _op_def_lib._apply_op_helper("Mul", name: "mul", keywords: keywords);
var _op = _op_def_lib._apply_op_helper("Pow", name: "Pow", keywords: keywords);

return new Tensor(_op, 0, _op.OutputType(0));
}

public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false)
public static Tensor sum(Tensor x, Tensor y)
{
var keywords = new Dictionary<string, object>();
keywords.Add("a", a);
keywords.Add("b", b);
keywords.Add("transpose_a", transpose_a);
keywords.Add("transpose_b", transpose_b);
keywords.Add("x", x);
keywords.Add("y", y);

var _op = _op_def_lib._apply_op_helper("MatMul", name: "MatMul", keywords: keywords);
var _op = _op_def_lib._apply_op_helper("Pow", name: "Pow", keywords: keywords);

return new Tensor(_op, 0, _op.OutputType(0));
}


+ 1
- 1
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -47,7 +47,7 @@ namespace Tensorflow
return result;
}

private unsafe object _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null)
private object _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null)
{
var feed_dict_tensor = new Dictionary<Tensor, NDArray>();



+ 5
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -11,6 +11,11 @@ namespace Tensorflow
return gen_math_ops.add(t1, t2);
}

public static Tensor operator -(Tensor t1, Tensor t2)
{
return gen_math_ops.sub(t1, t2);
}

public static Tensor operator *(Tensor t1, Tensor t2)
{
return gen_math_ops.mul(t1, t2);


+ 10
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -211,6 +211,11 @@ namespace Tensorflow
status.Dispose();
}

public static implicit operator Tensor(int scalar)
{
return new Tensor(scalar);
}

public static implicit operator IntPtr(Tensor tensor)
{
return tensor._handle;
@@ -220,5 +225,10 @@ namespace Tensorflow
{
return new Tensor(handle);
}

public static implicit operator Tensor(RefVariable var)
{
return var._initial_value;
}
}
}

+ 2
- 1
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -40,7 +40,8 @@ namespace TensorFlowNET.Examples
var pred = tf.add(part1, b);

// Mean squared error
var cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples);
var pow = tf.pow(pred - Y, 2);
//var cost = tf.reduce_sum(pow) / (2 * n_samples);
}
}
}

Loading…
Cancel
Save