Browse Source

k-means input has only 0 dims for 'cond/strided_slice'

tags/v0.9
Oceania2018 6 years ago
parent
commit
b3a2422e4d
11 changed files with 155 additions and 15 deletions
  1. +50
    -8
      src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs
  2. +8
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  3. +8
    -0
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  4. +7
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  5. +13
    -0
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  6. +19
    -1
      src/TensorFlowNET.Core/Operations/math_ops.cs
  7. +22
    -0
      src/TensorFlowNET.Core/Operations/nn_impl.py.cs
  8. +2
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  9. +5
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  10. +2
    -2
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
  11. +19
    -4
      src/TensorFlowNET.Core/Variables/state_ops.cs

+ 50
- 8
src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs View File

@@ -52,17 +52,24 @@ namespace Tensorflow.Clustering
_num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray());
}

private Operation[] _initialize()
private Tensor[] _initialize()
{
with(ops.control_dependencies(new Operation[]
return with(ops.control_dependencies(new Operation[]
{
check_ops.assert_positive(_num_remaining)
}), delegate
{
// var num_now_remaining = _add_new_centers();
var num_now_remaining = _add_new_centers();
return control_flow_ops.cond(math_ops.equal(num_now_remaining, 0),
() =>
{
return new Tensor[] { state_ops.assign(_cluster_centers_initialized, true) };
},
() =>
{
return new Tensor[] { control_flow_ops.no_op().output[0] };
});
});

throw new NotImplementedException("_InitializeClustersOpFactory _initialize");
}

public Tensor[] op()
@@ -71,14 +78,49 @@ namespace Tensorflow.Clustering
() =>
{
var op = check_ops.assert_equal(_cluster_centers_initialized, true);
return new Operation[] { op };
return new Tensor[] { op.output[0] };
},
_initialize);
}

/*private int _add_new_centers()
private Tensor _add_new_centers()
{
// Adds some centers and returns the number of centers remaining.
var new_centers = _choose_initial_centers();
}*/
if (_distance_metric == KMeans.COSINE_DISTANCE)
new_centers = nn_impl.l2_normalize(new_centers[0], axis: 1);

// If cluster_centers is empty, it doesn't have the right shape for concat.
var all_centers = control_flow_ops.cond(math_ops.equal(_num_selected, 0),
() => new Tensor[] { new_centers },
() => new Tensor[] { gen_array_ops.concat(new Tensor[] { _cluster_centers, new_centers }, 0) });

var a = state_ops.assign(_cluster_centers, all_centers, validate_shape: false);

return _num_clusters - array_ops.shape(a)[0];
}

private Tensor _choose_initial_centers()
{
return _greedy_batch_sampler()[0];
}

private Tensor[] _greedy_batch_sampler()
{
return control_flow_ops.cond(_num_data <= _num_remaining,
() =>
{
return new Tensor[] { gen_array_ops.concat(_inputs, 0) };
},
() =>
{
return new Tensor[] { _random() };
});
}

private Tensor _random()
{
throw new NotImplementedException("");
}
}
}

+ 8
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -75,6 +75,8 @@ namespace Tensorflow.Operations
{
case Tensor[] results:
return (original_result, results);
case Operation[] results:
return (original_result, new Tensor[] { _BuildCondTensor (results) });
case float[] fv:
var result = ops.convert_to_tensor(fv[0]);
return (original_result, new Tensor[] { result });
@@ -82,5 +84,11 @@ namespace Tensorflow.Operations
return (original_result, new Tensor[0]);
}
}

private Tensor _BuildCondTensor(Operation[] v)
{
// Use pivot as the proxy for this op.
return control_flow_ops.with_dependencies(v, _pivot);
}
}
}

+ 8
- 0
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -64,6 +64,14 @@ namespace Tensorflow
});
}

/// <summary>
/// Does nothing. Only useful as a placeholder for control edges.
/// </summary>
/// <param name="name"></param>
/// <returns></returns>
public static Operation no_op(string name = null)
=> gen_control_flow_ops.no_op(name: name);

private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = null)
{
return with(ops.control_dependencies(deps), ctl =>


+ 7
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -12,6 +12,13 @@ namespace Tensorflow
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
public static Execute _execute = new Execute();

public static Tensor concat(Tensor[] values, int axis, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis });

return _op.outputs[0];
}

public static Tensor expand_dims(Tensor input, int axis, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ExpandDims", name: name, args: new { input, dim = axis });


+ 13
- 0
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -428,5 +428,18 @@ namespace Tensorflow
return _op.outputs[0];
}
/// <summary>
/// Computes reciprocal of square root of x element-wise.
/// </summary>
/// <param name="x"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor rsqrt(Tensor x, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Rsqrt", name, new { x });
return _op.outputs[0];
}
}
}

+ 19
- 1
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -63,6 +63,13 @@ namespace Tensorflow
return x;
});
}

public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.equal(x, y, name: name);

public static Tensor multiply(Tensor x, Tensor y, string name = null)
=> gen_math_ops.mul(x, y, name: name);

/// <summary>
/// Computes the mean of elements across dimensions of a tensor.
/// Reduces `input_tensor` along the dimensions given in `axis`.
@@ -327,7 +334,15 @@ namespace Tensorflow
return range(0, rank, 1);
}
}

/// <summary>
/// Computes reciprocal of square root of x element-wise.
/// </summary>
/// <param name="x"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor rsqrt(Tensor x, string name = null)
=> gen_math_ops.rsqrt(x, name: name);

public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range" )
{
@@ -373,6 +388,9 @@ namespace Tensorflow
});
}

public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.maximum(x, y, name: name);

public static Tensor matmul(Tensor a, Tensor b,
bool transpose_a = false, bool transpose_b = false,
bool adjoint_a = false, bool adjoint_b = false,


+ 22
- 0
src/TensorFlowNET.Core/Operations/nn_impl.py.cs View File

@@ -7,6 +7,28 @@ namespace Tensorflow
{
public class nn_impl : Python
{
/// <summary>
/// Normalizes along dimension `axis` using an L2 norm.
/// </summary>
/// <param name="x"></param>
/// <param name="axis"></param>
/// <param name="epsilon"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor l2_normalize(Tensor x,
int axis = 0,
float epsilon = 1e-12f,
string name = null)
{
return with(ops.name_scope(name, "", new { x }), scope =>
{
x = ops.convert_to_tensor(x, name: "x");
var square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims: true);
var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon));
return math_ops.multiply(x, x_inv_norm, name: name);
});
}

/// <summary>
/// Calculate the mean and variance of `x`
/// </summary>


+ 2
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -28,9 +28,11 @@ namespace Tensorflow
public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y);

public static Tensor operator >(Tensor x, int y) => gen_math_ops.greater(x, y);
public static Tensor operator >=(Tensor x, Tensor y) => gen_math_ops.greater_equal(x, y);
public static Tensor operator >(Tensor x, float y) => gen_math_ops.greater(x, y);
public static Tensor operator >(Tensor x, double y) => gen_math_ops.greater(x, y);
public static Tensor operator <(Tensor x, int y) => gen_math_ops.less(x, y);
public static Tensor operator <=(Tensor x, Tensor y) => gen_math_ops.less_equal(x, y);
public static Tensor operator <(Tensor x, float y) => gen_math_ops.less(x, y);
public static Tensor operator <(Tensor x, double y) => gen_math_ops.less(x, y);



+ 5
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -134,5 +134,10 @@ namespace Tensorflow
{
return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE;
}

public static bool is_ref_dtype(this TF_DataType type)
{
return (int)type > 100;
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -47,12 +47,12 @@ namespace Tensorflow
/// <param name="validate_shape"></param>
/// <param name="use_locking"></param>
/// <param name="name"></param>
public static Tensor assign(Tensor tensor, object value,
public static Tensor assign(Tensor @ref, object value,
bool validate_shape = true,
bool use_locking = true,
string name = null)
{
var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref = tensor, value, validate_shape, use_locking });
var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking });

var _result = _op.outputs;
var _inputs_flat = _op.inputs;


+ 19
- 4
src/TensorFlowNET.Core/Variables/state_ops.cs View File

@@ -19,12 +19,27 @@ namespace Tensorflow
TF_DataType dtype,
string name = "Variable",
string container = "",
string shared_name = "") => gen_state_ops.variable_v2(shape,
dtype,
name: name,
container: container,
string shared_name = "") => gen_state_ops.variable_v2(shape,
dtype,
name: name,
container: container,
shared_name: shared_name);

public static Tensor assign(Tensor @ref, object value,
bool validate_shape = true,
bool use_locking = true,
string name = null)
{
if (@ref.dtype.is_ref_dtype())
return gen_state_ops.assign(@ref,
value,
validate_shape: validate_shape,
use_locking: use_locking,
name: name);
else
throw new NotImplementedException("state_ops.assign");
}

public static Tensor assign_sub(RefVariable @ref,
Tensor value,
bool use_locking = false,


Loading…
Cancel
Save