Browse Source

reduce_sum in progress

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
126ed2766f
3 changed files with 53 additions and 4 deletions
  1. +14
    -1
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +10
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  3. +29
    -3
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs

+ 14
- 1
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -26,9 +26,22 @@ namespace Tensorflow
return gen_math_ops.pow(x, y);
}

/// <summary>
/// Computes the sum of elements across dimensions of a tensor.
/// </summary>
/// <param name="input"></param>
/// <param name="axis"></param>
/// <returns></returns>
public static Tensor reduce_sum(Tensor input, int[] axis = null)
{
return gen_math_ops.sum(input, axis);
Tensor rank;
using (var namescop = new ops.name_scope<Tensor>("", "Rank", new List<Tensor> { input }))
{
string name = namescop;
rank = gen_array_ops.rank(input, name);
}
var s = gen_math_ops.sum(input, rank);
return gen_math_ops.range(0, s);
}
}
}

+ 10
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -41,5 +41,15 @@ namespace Tensorflow

return _op.outputs[0];
}

public static Tensor rank(Tensor input, string name = "")
{
var keywords = new Dictionary<string, object>();
keywords.Add("input", input);

var _op = _op_def_lib._apply_op_helper("Rank", name: name, keywords: keywords);

return _op.outputs[0];
}
}
}

+ 29
- 3
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -78,17 +78,43 @@ namespace Tensorflow
return new Tensor(_op, 0, _op.OutputType(0));
}

public static Tensor sum(Tensor input, int[] axis = null)
public static Tensor sum(Tensor input, Tensor axis = null)
{
if(axis == null) axis = new int[0];
var keywords = new Dictionary<string, object>();
keywords.Add("input", input);
keywords.Add("reduction_indices", constant_op.Constant(axis));
keywords.Add("reduction_indices", axis);
keywords.Add("keep_dims", false);

var _op = _op_def_lib._apply_op_helper("Sum", keywords: keywords);

return new Tensor(_op, 0, _op.OutputType(0));
}

/// <summary>
/// Creates a sequence of numbers.
/// </summary>
/// <param name="start"></param>
/// <param name="limit"></param>
/// <param name="delta"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor range(int start, Tensor limit, int delta = 1)
{
using (var namescope = new ops.name_scope<Tensor>("", "Range", new List<Tensor> { start, limit, delta }))
{
var start1 = ops.convert_to_tensor(start, "start");
var limit1 = ops.convert_to_tensor(limit, "limit");
var delta1 = ops.convert_to_tensor(delta, "delta");

var keywords = new Dictionary<string, object>();
keywords.Add("start", start1);
keywords.Add("limit", limit1);
keywords.Add("delta", delta1);

var _op = _op_def_lib._apply_op_helper("Range", namescope, keywords);

return _op.outputs[0];
}
}
}
}

Loading…
Cancel
Save