
fix name issue in placeholder.

control_flow_ops.switch.
tags/v0.8.0
haiping008 6 years ago
commit c26fccf856
12 changed files with 188 additions and 16 deletions
  1. +23 -0   src/TensorFlowNET.Core/Framework/smart_module.cs
  2. +10 -5   src/TensorFlowNET.Core/Graphs/Graph.cs
  3. +8  -5   src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  4. +52 -1   src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  5. +1  -1   src/TensorFlowNET.Core/Keras/Layers/Conv.cs
  6. +20 -0   src/TensorFlowNET.Core/Keras/Utils/generic_utils.cs
  7. +9  -0   src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
  8. +6  -3   src/TensorFlowNET.Core/Layers/Layer.cs
  9. +48 -0   src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  10. +1 -1   src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  11. +7 -0   src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
  12. +3 -0   src/TensorFlowNET.Core/Variables/VariableSynchronization.cs

src/TensorFlowNET.Core/Framework/smart_module.cs  (+23 -0)

@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Framework
+{
+    public class smart_module
+    {
+        public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
+        {
+            return control_flow_ops.cond(pred,
+                true_fn: true_fn,
+                false_fn: false_fn,
+                name: name);
+        }
+
+        public static bool smart_constant_value(Tensor pred)
+        {
+            var pred_value = tensor_util.constant_value(pred);
+            return pred_value;
+        }
+    }
+}
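A minimal usage sketch (not part of the commit) of how the new helper is meant to be driven; it reuses the gen_array_ops.placeholder signature shown further down in this diff, and the tensor and branch names are illustrative only:

    // Build a boolean placeholder and dispatch graph construction on it.
    var training = gen_array_ops.placeholder(TF_DataType.TF_BOOL, name: "is_training");
    smart_module.smart_cond(training,
        true_fn: () => { /* build the training branch */ },
        false_fn: () => { /* build the inference branch */ },
        name: "norm_cond");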

src/TensorFlowNET.Core/Graphs/Graph.cs  (+10 -5)

@@ -20,7 +20,7 @@ namespace Tensorflow
         private Dictionary<string, int> _names_in_use;
         public int _version;
         private int _next_id_counter;
-        private List<String> _unfetchable_ops = new List<string>();
+        private List<Operation> _unfetchable_ops = new List<Operation>();
         private List<Tensor> _unfeedable_tensors = new List<Tensor>();
 
         public string _name_stack = "";
@@ -228,13 +228,13 @@ namespace Tensorflow
 
         public bool is_fetchable<T>(T tensor_or_op)
         {
-            if (tensor_or_op is Tensor)
+            if (tensor_or_op is Tensor tensor)
             {
-                return !_unfetchable_ops.Contains((tensor_or_op as Tensor).name);
+                return !_unfetchable_ops.Contains(tensor);
             }
-            else if (tensor_or_op is Operation)
+            else if (tensor_or_op is Operation op)
            {
-                return !_unfetchable_ops.Contains((tensor_or_op as Operation).name);
+                return !_unfetchable_ops.Contains(op);
            }
 
            return false;
@@ -372,6 +372,11 @@ namespace Tensorflow
             _unfeedable_tensors.Add(tensor);
         }
 
+        public void prevent_fetching(Operation op)
+        {
+            _unfetchable_ops.Add(op);
+        }
+
         public void Dispose()
         {
             c_api.TF_DeleteGraph(_handle);


src/TensorFlowNET.Core/Keras/Engine/Layer.cs  (+8 -5)

@@ -48,6 +48,7 @@ namespace Tensorflow.Keras.Engine
         }
 
         public Tensor __call__(Tensor inputs,
+            Tensor training = null,
             VariableScope scope = null)
         {
             var input_list = new Tensor[] { inputs };
@@ -73,7 +74,7 @@ namespace Tensorflow.Keras.Engine
                 // Symbolic execution on symbolic tensors. We will attempt to build
                 // the corresponding TF subgraph inside `backend.get_graph()`
                 var graph = backend.get_graph();
-                outputs = call(inputs);
+                outputs = call(inputs, training: training);
                 _handle_activity_regularization(inputs, outputs);
                 _set_mask_metadata(inputs, outputs, null);
             }
@@ -100,7 +101,7 @@ namespace Tensorflow.Keras.Engine
             return null;
         }
 
-        protected virtual Tensor call(Tensor inputs)
+        protected virtual Tensor call(Tensor inputs, Tensor training = null)
         {
             throw new NotImplementedException("Layer.call");
         }
@@ -143,13 +144,15 @@ namespace Tensorflow.Keras.Engine
 
         protected virtual void _init_set_name(string name)
         {
-            if (string.IsNullOrEmpty(name))
-                (_name, _base_name) = _make_unique_name();
+            string base_name = name;
+            if (name == null)
+                (_name, base_name) = _make_unique_name();
+            _base_name = base_name;
         }
 
         protected virtual (string, string) _make_unique_name()
         {
-            string base_name = "conv2d";
+            string base_name = generic_utils.to_snake_case(this.GetType().Name);
             string name = base_layer_utils.unique_layer_name(base_name);
             return (name, base_name);
         }


src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs  (+52 -1)

@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
+using Tensorflow.Keras.Utils;
 using Tensorflow.Layers;
 
 namespace Tensorflow.Keras.Layers
@@ -25,6 +26,7 @@ namespace Tensorflow.Keras.Layers
         private RefVariable gamma;
         private RefVariable beta;
         private RefVariable moving_mean;
+        private RefVariable moving_variance;
 
         public BatchNormalization(int axis = -1,
             float momentum = 0.99f,
@@ -103,7 +105,56 @@ namespace Tensorflow.Keras.Layers
 
             moving_mean = add_weight("moving_mean",
                 param_shape,
-                dtype: param_dtype);
+                dtype: param_dtype,
+                initializer: moving_mean_initializer,
+                synchronization: VariableSynchronization.ON_READ,
+                trainable: false,
+                aggregation: VariableAggregation.MEAN);
+
+            moving_variance = add_weight("moving_variance",
+                shape: param_shape,
+                dtype: param_dtype,
+                initializer: moving_variance_initializer,
+                synchronization: VariableSynchronization.ON_READ,
+                trainable: false,
+                aggregation: VariableAggregation.MEAN);
+
+            if (renorm)
+                throw new NotImplementedException("build when renorm is true");
+
+            built = true;
+        }
+
+        protected override Tensor call(Tensor inputs, Tensor training = null)
+        {
+            Tensor outputs = null;
+
+            if (fused)
+            {
+                outputs = _fused_batch_norm(inputs, training: training);
+            }
+
+            throw new NotImplementedException("BatchNormalization call");
+        }
+
+        private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
+        {
+            var beta = this.beta;
+            var gamma = this.gamma;
+
+            Action _fused_batch_norm_training = () =>
+            {
+
+            };
+
+            Action _fused_batch_norm_inference = () =>
+            {
+
+            };
+
+            tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);
+
+            throw new NotImplementedException("_fused_batch_norm");
+        }
     }
 }

src/TensorFlowNET.Core/Keras/Layers/Conv.cs  (+1 -1)

@@ -91,7 +91,7 @@ namespace Tensorflow.Keras.Layers
             built = true;
         }
 
-        protected override Tensor call(Tensor inputs)
+        protected override Tensor call(Tensor inputs, Tensor training = null)
         {
             var outputs = _convolution_op.__call__(inputs, kernel);
             if (use_bias)


src/TensorFlowNET.Core/Keras/Utils/generic_utils.cs  (+20 -0)

@@ -0,0 +1,20 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Tensorflow.Keras.Utils
+{
+    public class generic_utils
+    {
+        public static string to_snake_case(string name)
+        {
+            return string.Concat(name.Select((x, i) =>
+            {
+                return i > 0 && char.IsUpper(x) && !Char.IsDigit(name[i - 1]) ?
+                    "_" + x.ToString() :
+                    x.ToString();
+            })).ToLower();
+        }
+    }
+}
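For reference, a small sketch (not part of the commit) of what the new helper produces for the layer types touched in this change; this is what _make_unique_name now relies on instead of the hard-coded "conv2d":

    // An underscore is inserted before an upper-case letter unless the previous
    // character is a digit, then the whole string is lower-cased.
    var a = generic_utils.to_snake_case("Conv2D");             // "conv2d"
    var b = generic_utils.to_snake_case("BatchNormalization"); // "batch_normalization"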

src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs  (+9 -0)

@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
+using Tensorflow.Framework;
 
 namespace Tensorflow.Keras.Utils
 {
@@ -16,5 +17,13 @@ namespace Tensorflow.Keras.Utils
         {
             return true;
         }
+
+        public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
+        {
+            return smart_module.smart_cond(pred,
+                true_fn: true_fn,
+                false_fn: false_fn,
+                name: name);
+        }
     }
 }

src/TensorFlowNET.Core/Layers/Layer.cs  (+6 -3)

@@ -29,10 +29,11 @@ namespace Tensorflow.Layers
 
         public virtual Tensor apply(Tensor inputs, Tensor training = null)
         {
-            return __call__(inputs);
+            return __call__(inputs, training: training);
         }
 
         public Tensor __call__(Tensor inputs,
+            Tensor training = null,
             VariableScope scope = null)
         {
             _set_scope(scope);
@@ -51,7 +52,7 @@ namespace Tensorflow.Layers
 
             Python.with(scope_context_manager, scope2 => _current_scope = scope2);
             // Actually call layer
-            var outputs = base.__call__(inputs);
+            var outputs = base.__call__(inputs, training: training);
 
             // Update global default collections.
             //_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS);
@@ -63,7 +64,9 @@ namespace Tensorflow.Layers
             int[] shape,
             TF_DataType dtype = TF_DataType.DtInvalid,
             IInitializer initializer = null,
-            bool? trainable = null)
+            bool? trainable = null,
+            VariableSynchronization synchronization = VariableSynchronization.AUTO,
+            VariableAggregation aggregation = VariableAggregation.NONE)
         {
             var default_graph = ops.get_default_graph();
             Graph init_graph = null;


src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs  (+48 -0)

@@ -135,5 +135,53 @@ namespace Tensorflow
             else
                 return gen_array_ops.identity(data, name: name);
         }
+
+        public static (Tensor, Tensor) cond(Tensor pred,
+            Action true_fn = null,
+            Action false_fn = null,
+            bool strict = false,
+            string name = null)
+        {
+            return with(ops.name_scope(name, "cond", new { pred }), delegate
+            {
+                // Add the Switch to the graph.
+                var (p_2, p_1) = @switch(pred, pred);
+                var pivot_1 = array_ops.identity(p_1, name: "switch_t");
+                var pivot_2 = array_ops.identity(p_2, name: "switch_f");
+                pred = array_ops.identity(pred, name: "pred_id");
+
+                // Disable the fetching of tensors that are only on one branch of cond.
+                foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
+                    tensor.op.graph.prevent_fetching(tensor.op);
+
+                return (p_2, p_1);
+            });
+        }
+
+        /// <summary>
+        /// Forwards `data` to an output determined by `pred`.
+        /// </summary>
+        /// <param name="data"></param>
+        /// <param name="pred"></param>
+        /// <param name="dtype"></param>
+        /// <param name="name"></param>
+        public static (Tensor, Tensor) @switch(Tensor data,
+            Tensor pred,
+            TF_DataType dtype = TF_DataType.DtInvalid,
+            string name = null)
+        {
+            return with(ops.name_scope(name, "Switch", new { data, pred }), scope =>
+            {
+                name = scope;
+                data = ops.internal_convert_to_tensor_or_indexed_slices(data,
+                    dtype: dtype,
+                    name: "data",
+                    as_ref: true);
+
+                pred = ops.convert_to_tensor(pred, name: "pred");
+
+                return gen_control_flow_ops.@switch(data, pred, name: name);
+            });
+        }
     }
 }
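A hedged sketch (not part of the commit) of what the new @switch wrapper yields, following the standard TensorFlow Switch-op semantics: data is forwarded to output 0 only when pred is false and to output 1 only when pred is true. At this stage cond only wires up the Switch, the two identity pivots and pred_id, and marks them unfetchable via the new Graph.prevent_fetching; the branch delegates are not yet executed. The names below are illustrative:

    var pred = gen_array_ops.placeholder(TF_DataType.TF_BOOL, name: "pred");
    var data = gen_array_ops.placeholder(TF_DataType.TF_FLOAT, name: "data");
    // output_false carries data only when pred evaluates to false,
    // output_true only when pred evaluates to true.
    var (output_false, output_true) = control_flow_ops.@switch(data, pred);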

src/TensorFlowNET.Core/Operations/gen_array_ops.cs  (+1 -1)

@@ -42,7 +42,7 @@ namespace Tensorflow
 
         public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
         {
-            var _op = _op_def_lib._apply_op_helper("Placeholder", args: new { dtype, shape });
+            var _op = _op_def_lib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape });
             var _result = _op.outputs;
             var _inputs_flat = _op.inputs;
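For context, a minimal sketch (not part of the commit) of the behaviour this one-line change restores; previously the name argument was dropped before reaching _apply_op_helper, so the created op always fell back to the default "Placeholder" name:

    var x = gen_array_ops.placeholder(TF_DataType.TF_FLOAT, name: "input");
    // With name forwarded, the underlying op is now named "input"
    // (or a uniquified variant such as "input_1" if that name is already taken).
    Console.WriteLine(x.op.name);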




src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs  (+7 -0)

@@ -14,5 +14,12 @@ namespace Tensorflow
 
             return _op;
         }
+
+        public static (Tensor, Tensor) @switch(Tensor data, Tensor pred, string name = null)
+        {
+            var _op = _op_def_lib._apply_op_helper("Switch", name, new { data, pred });
+
+            return (_op.outputs[0], _op.outputs[1]);
+        }
     }
 }

src/TensorFlowNET.Core/Variables/VariableSynchronization.cs  (+3 -0)

@@ -4,6 +4,9 @@ using System.Text;
 
 namespace Tensorflow
 {
+    /// <summary>
+    /// Indicates when a distributed variable will be synced.
+    /// </summary>
     public enum VariableSynchronization
     {
         AUTO = 0,

