Browse Source

assign_sub, smart_cond

tags/v0.8.0
haiping008 6 years ago
parent
commit
16a59afe37
23 changed files with 284 additions and 52 deletions
  1. +26
    -0
      src/TensorFlowNET.Core/APIs/tf.layers.cs
  2. +16
    -13
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  3. +7
    -4
      src/TensorFlowNET.Core/Framework/smart_module.cs
  4. +9
    -0
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  5. +37
    -4
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  6. +24
    -0
      src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs
  7. +33
    -0
      src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
  8. +15
    -0
      src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs
  9. +8
    -3
      src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
  10. +14
    -1
      src/TensorFlowNET.Core/Layers/Layer.cs
  11. +11
    -2
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  12. +7
    -2
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  13. +34
    -8
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  14. +8
    -1
      src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
  15. +1
    -1
      src/TensorFlowNET.Core/Operations/nn_impl.py.cs
  16. +0
    -8
      src/TensorFlowNET.Core/Python.cs
  17. +2
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  18. +2
    -1
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  19. +2
    -1
      src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs
  20. +11
    -1
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
  21. +8
    -0
      src/TensorFlowNET.Core/Variables/state_ops.cs
  22. +4
    -0
      src/TensorFlowNET.Core/ops.GraphKeys.cs
  23. +5
    -1
      test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs

+ 26
- 0
src/TensorFlowNET.Core/APIs/tf.layers.cs View File

@@ -100,6 +100,32 @@ namespace Tensorflow

return layer.apply(inputs, training: training);
}

/// <summary>
/// Max pooling layer for 2D inputs (e.g. images).
/// </summary>
/// <param name="inputs">The tensor over which to pool. Must have rank 4.</param>
/// <param name="pool_size"></param>
/// <param name="strides"></param>
/// <param name="padding"></param>
/// <param name="data_format"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor max_pooling2d(Tensor inputs,
int[] pool_size,
int[] strides,
string padding = "valid",
string data_format = "channels_last",
string name = null)
{
var layer = new MaxPooling2D(pool_size: pool_size,
strides: strides,
padding: padding,
data_format: data_format,
name: name);

return layer.apply(inputs);
}
}
}
}

+ 16
- 13
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;
using Tensorflow.Operations.Activation;

namespace Tensorflow
@@ -27,19 +28,21 @@ namespace Tensorflow

public static IActivation relu => new relu();

public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x,
RefVariable scale,
RefVariable offset,
Tensor mean = null,
Tensor variance = null,
float epsilon = 0.001f,
string data_format = "NHWC",
bool is_training = true,
string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
epsilon: epsilon,
data_format: data_format,
is_training: is_training,
name: name);
public static Tensor[] fused_batch_norm(Tensor x,
RefVariable scale,
RefVariable offset,
Tensor mean = null,
Tensor variance = null,
float epsilon = 0.001f,
string data_format = "NHWC",
bool is_training = true,
string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
epsilon: epsilon,
data_format: data_format,
is_training: is_training,
name: name);

public static Tensor max_pool() => gen_nn_ops.max_pool();
}
}
}

+ 7
- 4
src/TensorFlowNET.Core/Framework/smart_module.cs View File

@@ -6,9 +6,9 @@ namespace Tensorflow.Framework
{
public class smart_module
{
public static object smart_cond(Tensor pred,
Func<(Tensor, Tensor, Tensor)> true_fn = null,
Func<(Tensor, Tensor, Tensor)> false_fn = null,
public static Tensor[] smart_cond<T>(Tensor pred,
Func<T[]> true_fn = null,
Func<T[]> false_fn = null,
string name = null)
{
return control_flow_ops.cond(pred,
@@ -17,9 +17,12 @@ namespace Tensorflow.Framework
name: name);
}

public static bool smart_constant_value(Tensor pred)
public static bool? smart_constant_value(Tensor pred)
{
var pred_value = tensor_util.constant_value(pred);
if (pred_value is null)
return null;

return pred_value;
}
}


+ 9
- 0
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Utils;

@@ -34,6 +35,7 @@ namespace Tensorflow.Keras.Engine
protected string _name;
protected string _base_name;
protected bool _compute_previous_mask;
protected List<Operation> _updates;

public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
@@ -45,6 +47,7 @@ namespace Tensorflow.Keras.Engine
_init_set_name(name);
_trainable_weights = new List<RefVariable>();
_compute_previous_mask = false;
_updates = new List<Operation>();
}

public Tensor __call__(Tensor inputs,
@@ -142,6 +145,12 @@ namespace Tensorflow.Keras.Engine
return variable;
}

protected virtual void add_update(Tensor[] updates, bool inputs = false)
{
var updates_op = updates.Select(x => x.op).ToArray();
_updates.AddRange(updates_op);
}

protected virtual void _init_set_name(string name)
{
string base_name = name;


+ 37
- 4
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -132,6 +132,7 @@ namespace Tensorflow.Keras.Layers
if (fused)
{
outputs = _fused_batch_norm(inputs, training: training);
return outputs;
}

throw new NotImplementedException("BatchNormalization call");
@@ -142,7 +143,7 @@ namespace Tensorflow.Keras.Layers
var beta = this.beta;
var gamma = this.gamma;

Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () =>
Func<Tensor[]> _fused_batch_norm_training = () =>
{
return tf.nn.fused_batch_norm(
inputs,
@@ -152,7 +153,7 @@ namespace Tensorflow.Keras.Layers
data_format: _data_format);
};

Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () =>
Func<Tensor[]> _fused_batch_norm_inference = () =>
{
return tf.nn.fused_batch_norm(
inputs,
@@ -165,9 +166,41 @@ namespace Tensorflow.Keras.Layers
data_format: _data_format);
};

tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);
var results = tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);
var (output, mean, variance) = (results[0], results[1], results[2]);
var training_value = tf_utils.constant_value(training);

throw new NotImplementedException("_fused_batch_norm");
Tensor momentum_tensor;
if (training_value == null)
{
momentum_tensor = tf_utils.smart_cond(training,
() => new float[] { momentum }, () => new float[] { 1.0f })[0];
}
else
{
momentum_tensor = ops.convert_to_tensor(momentum);
}
if(training_value == null)
{
var mean_update = _assign_moving_average(moving_mean, mean, momentum_tensor);
var variance_update = _assign_moving_average(moving_variance, variance, momentum_tensor);
add_update(new Tensor[] { mean_update }, inputs: true);
add_update(new Tensor[] { variance_update }, inputs: true);
}

return output;
}

public Tensor _assign_moving_average(RefVariable variable, Tensor value, Tensor momentum)
{
return Python.with(ops.name_scope(null, "AssignMovingAvg", new { variable, value, momentum }), scope =>
{
// var cm = ops.colocate_with(variable);
var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay");
var update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay;
return state_ops.assign_sub(variable, update_delta, name: scope);
});
}
}
}

+ 24
- 0
src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs View File

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.tf;

namespace Tensorflow.Keras.Layers
{
public class MaxPooling2D : Pooling2D
{
public MaxPooling2D(
int[] pool_size,
int[] strides,
string padding = "valid",
string data_format = null,
string name = null) : base(nn.max_pool, pool_size,
strides,
padding: padding,
data_format: data_format,
name: name)
{

}
}
}

+ 33
- 0
src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers
{
public class Pooling2D : Tensorflow.Layers.Layer
{
private Func<Tensor> pool_function;
private int[] pool_size;
private int[] strides;
private string padding;
private string data_format;
private InputSpec input_spec;

public Pooling2D(Func<Tensor> pool_function,
int[] pool_size,
int[] strides,
string padding = "valid",
string data_format = null,
string name = null) : base(name: name)
{
this.pool_function = pool_function;
this.pool_size = conv_utils.normalize_tuple(pool_size, 2, "pool_size");
this.strides = conv_utils.normalize_tuple(strides, 2, "strides");
this.padding = conv_utils.normalize_padding(padding);
this.data_format = conv_utils.normalize_data_format(data_format);
this.input_spec = new InputSpec(ndim: 4);
}
}
}

+ 15
- 0
src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs View File

@@ -29,5 +29,20 @@ namespace Tensorflow.Keras.Utils
else
throw new ValueError($"Invalid data_format: {data_format}");
}

public static int[] normalize_tuple(int[] value, int n, string name)
{
return value;
}

public static string normalize_padding(string value)
{
return value.ToLower();
}

public static string normalize_data_format(string value)
{
return value.ToLower();
}
}
}

+ 8
- 3
src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs View File

@@ -13,14 +13,19 @@ namespace Tensorflow.Keras.Utils
return tensors.Select(x => is_symbolic_tensor(x)).Count() == tensors.Length;
}

public static bool? constant_value(Tensor pred)
{
return smart_module.smart_constant_value(pred);
}

public static bool is_symbolic_tensor(Tensor tensor)
{
return true;
}

public static object smart_cond(Tensor pred,
Func<(Tensor, Tensor, Tensor)> true_fn = null,
Func<(Tensor, Tensor, Tensor)> false_fn = null,
public static Tensor[] smart_cond<T>(Tensor pred,
Func<T[]> true_fn = null,
Func<T[]> false_fn = null,
string name = null)
{
return smart_module.smart_cond(pred,


+ 14
- 1
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Engine;

@@ -55,11 +56,23 @@ namespace Tensorflow.Layers
var outputs = base.__call__(inputs, training: training);

// Update global default collections.
//_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS);
_add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS });

return outputs;
}

protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list)
{
foreach(var name in collection_list)
{
var collection = ops.get_collection_ref(name) as List<object>;

foreach (var element in elements)
if (!collection.Contains(element))
collection.Add(element);
}
}

protected virtual RefVariable add_weight(string name,
int[] shape,
TF_DataType dtype = TF_DataType.DtInvalid,


+ 11
- 2
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -63,14 +63,23 @@ namespace Tensorflow.Operations
}
}

public (Tensor, Tensor, Tensor) BuildCondBranch(Func<(Tensor, Tensor, Tensor)> fn)
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
{
// Add the subgraph defined by fn() to the graph.
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);

return original_result;
switch (original_result)
{
case Tensor[] results:
return (original_result, results);
case float[] fv:
var result = ops.convert_to_tensor(fv[0]);
return (original_result, new Tensor[] { result });
default:
return (original_result, new Tensor[0]);
}
}
}
}

+ 7
- 2
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -53,7 +53,7 @@ namespace Tensorflow.Operations
return _op.outputs[0];
}

public static (Tensor, Tensor, Tensor) _fused_batch_norm(Tensor x,
public static Tensor[] _fused_batch_norm(Tensor x,
Tensor scale,
Tensor offset,
Tensor mean,
@@ -75,7 +75,12 @@ namespace Tensorflow.Operations
is_training
});

return (_op.outputs[0], _op.outputs[1], _op.outputs[2]);
return _op.outputs;
}

public static Tensor max_pool()
{
throw new NotImplementedException("");
}
}
}

+ 34
- 8
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -137,9 +137,9 @@ namespace Tensorflow
return gen_array_ops.identity(data, name: name);
}

public static (Tensor, Tensor) cond(Tensor pred,
Func<(Tensor, Tensor, Tensor)> true_fn = null,
Func<(Tensor, Tensor, Tensor)> false_fn = null,
public static Tensor[] cond<T>(Tensor pred,
Func<T[]> true_fn = null,
Func<T[]> false_fn = null,
bool strict = false,
string name = null)
{
@@ -158,20 +158,46 @@ namespace Tensorflow
// Build the graph for the true branch in a new context.
var context_t = new CondContext(pred, pivot_1, branch: 1);
context_t.Enter();
var res_t = context_t.BuildCondBranch(true_fn);
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
context_t.Exit();

// Build the graph for the false branch in a new context.
var context_f = new CondContext(pred, pivot_2, branch: 0);
context_f.Enter();
var res_f = context_f.BuildCondBranch(false_fn);
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
context_f.Exit();

var res_t_flat = new Tensor[] { res_t.Item1, res_t.Item2, res_t.Item3 };
var res_f_flat = new Tensor[] { res_f.Item1, res_f.Item2, res_f.Item3 };
var res_t_flat = res_t;
var res_f_flat = res_f;

var merges = zip(res_f_flat, res_t_flat)
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
.ToArray();

return (p_2, p_1);
merges = _convert_flows_to_tensorarrays(orig_res_t, merges);

ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);

return merges;
});
}

public static Tensor[] _convert_flows_to_tensorarrays<T>(T[] tensors_or_tensorarrays, Tensor[] tensors_or_flows)
{
// zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray();
return tensors_or_flows;
}

public static Tensor merge(Tensor[] inputs, string name = null)
{
return with(ops.name_scope(name, "Merge", inputs), scope =>
{
name = scope;
inputs = inputs.Select(inp =>
ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true))
.ToArray();
return gen_control_flow_ops.merge(inputs, name).Item1;
});
}



+ 8
- 1
src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs View File

@@ -4,7 +4,7 @@ using System.Text;

namespace Tensorflow
{
public class gen_control_flow_ops
public class gen_control_flow_ops : Python
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();

@@ -21,5 +21,12 @@ namespace Tensorflow

return (_op.outputs[0], _op.outputs[1]);
}

public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs });

return (_op.outputs[0], _op.outputs[1]);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/nn_impl.py.cs View File

@@ -46,7 +46,7 @@ namespace Tensorflow
});
}

public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x,
public static Tensor[] fused_batch_norm(Tensor x,
RefVariable scale,
RefVariable offset,
Tensor mean,


+ 0
- 8
src/TensorFlowNET.Core/Python.cs View File

@@ -118,14 +118,6 @@ namespace Tensorflow
{
object obj = propertyDescriptor.GetValue(dyn);
string name = propertyDescriptor.Name;
// avoid .net keyword
switch (name)
{
case "_ref_":
name = "ref";
break;
}

dictionary.Add(name, obj);
}
return dictionary;


+ 2
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -14,6 +14,7 @@ namespace Tensorflow
public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("Sub", x, y);

public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y);
public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y);
@@ -48,7 +49,7 @@ namespace Tensorflow
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y");

switch (name)
switch (name.ToLower())
{
case "add":
result = gen_math_ops.add(x1, y1, name: scope);


+ 2
- 1
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -38,7 +38,8 @@ namespace Tensorflow
{
return MakeNdarray(tensor.op.get_attr("value") as TensorProto);
}
throw new NotImplementedException("_ConstantValue");

return null;
}

public static NDArray MakeNdarray(TensorProto tensor)


+ 2
- 1
src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs View File

@@ -13,7 +13,8 @@ namespace Tensorflow
public static Tensor operator -(RefVariable x, int y) => op_helper("sub", x, y);
public static Tensor operator -(RefVariable x, float y) => op_helper("sub", x, y);
public static Tensor operator -(RefVariable x, double y) => op_helper("sub", x, y);
public static Tensor operator -(RefVariable x, Tensor y) => op_helper("sub", x, y);

private static Tensor op_helper<T>(string default_name, RefVariable x, T y)
{
var tensor1 = x.value();


+ 11
- 1
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -52,7 +52,7 @@ namespace Tensorflow
bool use_locking = true,
string name = null)
{
var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { _ref_ = tensor, value, validate_shape, use_locking });
var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref = tensor, value, validate_shape, use_locking });

var _result = _op.outputs;
var _inputs_flat = _op.inputs;
@@ -66,5 +66,15 @@ namespace Tensorflow

return _result[0];
}

public static Tensor assign_sub(RefVariable @ref,
Tensor value,
bool use_locking = false,
string name = null)
{
var _op = _op_def_lib._apply_op_helper("AssignSub", name: name, args: new { @ref, value, use_locking });

return _op.outputs[0];
}
}
}

+ 8
- 0
src/TensorFlowNET.Core/Variables/state_ops.cs View File

@@ -24,5 +24,13 @@ namespace Tensorflow
name: name,
container: container,
shared_name: shared_name);

public static Tensor assign_sub(RefVariable @ref,
Tensor value,
bool use_locking = false,
string name = null) => gen_state_ops.assign_sub(@ref,
value,
use_locking: use_locking,
name: name);
}
}

+ 4
- 0
src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -47,6 +47,10 @@ namespace Tensorflow

// Used to store v2 summary names.
public static string _SUMMARY_COLLECTION = "_SUMMARY_V2";

// Key for control flow context.
public static string COND_CONTEXT = "cond_context";
public static string WHILE_CONTEXT = "while_context";
}
}
}

+ 5
- 1
test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs View File

@@ -93,7 +93,11 @@ namespace TensorFlowNET.Examples.TextClassification
if (max_pool)
{
// Max pooling
throw new NotImplementedException("conv_block");
return tf.layers.max_pooling2d(
conv,
pool_size: new int[] { 3, 1 },
strides: new int[] { 2, 1 },
padding: "SAME");
}
else
{


Loading…
Cancel
Save