Browse Source

x_emb shape is not correct #189

tags/v0.8.0
haiping008 6 years ago
parent
commit
bdd9beca6b
10 changed files with 240 additions and 61 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/APIs/tf.init.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  3. +5
    -2
      src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
  4. +24
    -51
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  5. +10
    -1
      src/TensorFlowNET.Core/Keras/Layers/Conv.cs
  6. +20
    -0
      src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
  7. +132
    -0
      src/TensorFlowNET.Core/Layers/Layer.cs
  8. +2
    -2
      src/TensorFlowNET.Core/Operations/embedding_ops.cs
  9. +8
    -0
      src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs
  10. +36
    -5
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs

+ 1
- 0
src/TensorFlowNET.Core/APIs/tf.init.cs View File

@@ -21,6 +21,7 @@ namespace Tensorflow
public static variable_scope variable_scope(VariableScope scope,
string default_name = null,
object values = null,
bool? reuse = null,
bool auxiliary_name_scope = true) => new variable_scope(scope,
default_name,
values,


+ 2
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -37,6 +37,8 @@ namespace Tensorflow
/// </summary>
private Dictionary<string, object> _collections = new Dictionary<string, object>();

public bool building_function;

public Graph()
{
_handle = c_api.TF_NewGraph();


+ 5
- 2
src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs View File

@@ -9,9 +9,12 @@ namespace Tensorflow.Keras.Engine
/// </summary>
public class InputSpec
{
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid)
{
public int ndim;

public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
int? ndim = null)
{
this.ndim = ndim.Value;
}
}
}

+ 24
- 51
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Engine
{
@@ -12,77 +13,49 @@ namespace Tensorflow.Keras.Engine
/// </summary>
public class Layer : CheckpointableBase
{
protected bool trainable;
protected string _name;
protected TF_DataType _dtype;
protected Graph _graph;
protected string _base_name;
protected VariableScope _scope;
/// <summary>
/// A stateful layer is a layer whose updates are run during inference too,
/// for instance stateful RNNs.
/// </summary>
protected bool stateful;
/// <summary>
/// Indicates whether `build` needs to be called upon layer call, to create
/// the layer's weights.
/// </summary>
protected bool built;
/// <summary>
/// Provides information about which inputs are compatible with the layer.
/// </summary>
protected InputSpec input_spec;
protected bool supports_masking;

public Layer(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid)
{
this.trainable = trainable;
this.stateful = false;
this.built = false;
this.supports_masking = false;
_init_set_name(name);
}

public Tensor apply(Tensor inputs)
{
return __call__(inputs);
}

public Tensor __call__(Tensor inputs,
VariableScope scope = null)
{
_set_scope(scope);
_graph = ops._get_graph_from_inputs(new List<Tensor> { inputs }, graph: _graph);
var scope_context_manager = tf.variable_scope(_scope);
var input_list = new Tensor[] { inputs };

// We will attempt to build a TF graph if & only if all inputs are symbolic.
// This is always the case in graph mode. It can also be the case in eager
// mode when all inputs can be traced back to `keras.Input()` (when building
// models using the functional API).
bool build_graph = tf_utils.are_all_symbolic_tensors(input_list);

// Handle Keras mask propagation from previous layer to current layer.
Python.with(new ops.name_scope(_name_scope()), delegate
{
if (!built)
{
_maybe_build(inputs);
}
});

throw new NotImplementedException("");
}

private void _init_set_name(string name)
protected virtual string _name_scope()
{
if (string.IsNullOrEmpty(name))
(_name, _base_name) = _make_unique_name();
return null;
}

private (string, string) _make_unique_name()
protected void _maybe_build(Tensor inputs)
{
string base_name = "conv2d";
string name = base_layer_utils.unique_layer_name(base_name);
return (name, base_name);
var input_list = new Tensor[] { inputs };
build(inputs.getShape());
}

private void _set_scope(VariableScope scope = null)
protected virtual void build(TensorShape input_shape)
{
if (_scope == null)
{
Python.with(tf.variable_scope(scope, default_name: _base_name), captured_scope =>
{
_scope = captured_scope;
});
}

}
}
}

+ 10
- 1
src/TensorFlowNET.Core/Keras/Layers/Conv.cs View File

@@ -6,7 +6,7 @@ using Tensorflow.Operations.Activation;

namespace Tensorflow.Keras.Layers
{
public class Conv : Layer
public class Conv : Tensorflow.Layers.Layer
{
protected int rank;
protected int filters;
@@ -45,6 +45,15 @@ namespace Tensorflow.Keras.Layers
this.use_bias = use_bias;
this.kernel_initializer = kernel_initializer;
this.bias_initializer = bias_initializer;
input_spec = new InputSpec(ndim: rank + 2);
}

protected override void build(TensorShape input_shape)
{
int channel_axis = data_format == "channels_first" ? 1 : -1;
int input_dim = input_shape.Dimensions[input_shape.NDim - 1];
var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters };
add_weight();
}
}
}

+ 20
- 0
src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Keras.Utils
{
public class tf_utils
{
public static bool are_all_symbolic_tensors(Tensor[] tensors)
{
return tensors.Select(x => is_symbolic_tensor(x)).Count() == tensors.Length;
}

public static bool is_symbolic_tensor(Tensor tensor)
{
return true;
}
}
}

+ 132
- 0
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -0,0 +1,132 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Layers
{
public class Layer : Keras.Engine.Layer
{
protected bool trainable;
protected string _name;
protected TF_DataType _dtype;
protected Graph _graph;
protected string _base_name;
protected VariableScope _scope;
protected VariableScope _current_scope;
/// <summary>
/// A stateful layer is a layer whose updates are run during inference too,
/// for instance stateful RNNs.
/// </summary>
protected bool stateful;
/// <summary>
/// Provides information about which inputs are compatible with the layer.
/// </summary>
protected InputSpec input_spec;
protected bool supports_masking;
protected bool? _reuse;

public Layer(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
bool? _reuse = null)
{
this.trainable = trainable;
this.stateful = false;
this._reuse = _reuse;
this.built = false;
this.supports_masking = false;
_init_set_name(name);
}

public Tensor apply(Tensor inputs)
{
return __call__(inputs);
}

public Tensor __call__(Tensor inputs,
VariableScope scope = null)
{
_set_scope(scope);
_graph = ops._get_graph_from_inputs(new List<Tensor> { inputs }, graph: _graph);

variable_scope scope_context_manager = null;
if (built)
{

}
else
{
scope_context_manager = tf.variable_scope(_scope,
auxiliary_name_scope: false);
}

Python.with(scope_context_manager, scope2 => _current_scope = scope2);
// Actually call layer
var outputs = base.__call__(inputs);

throw new NotImplementedException("");
}

protected virtual void add_weight()
{
var default_graph = ops.get_default_graph();
Graph init_graph = null;
RefVariable[] existing_variables = null;

if (default_graph.building_function)
{
throw new NotImplementedException("add_weight");
}
else
{
init_graph = default_graph;
existing_variables = variables.global_variables().ToArray();
}

var dtype = TF_DataType.TF_FLOAT;
_set_scope();
var reuse = built || (_reuse != null && _reuse.Value);
Python.with(tf.variable_scope(_scope,
reuse: reuse,
auxiliary_name_scope: false), scope =>
{
_current_scope = scope;
Python.with(new ops.name_scope(_name_scope()), delegate
{


});
});
}

private void _init_set_name(string name)
{
if (string.IsNullOrEmpty(name))
(_name, _base_name) = _make_unique_name();
}

private (string, string) _make_unique_name()
{
string base_name = "conv2d";
string name = base_layer_utils.unique_layer_name(base_name);
return (name, base_name);
}

protected override string _name_scope()
{
return _current_scope.original_name_scope;
}

private void _set_scope(VariableScope scope = null)
{
if (_scope == null)
{
Python.with(tf.variable_scope(scope, default_name: _base_name), captured_scope =>
{
_scope = captured_scope;
});
}
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Operations/embedding_ops.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow
if(np == 1)
{
var gather = array_ops.gather(@params, ids, name: name);
var result = _clip(@params, ids, max_norm);
var result = _clip(gather, ids, max_norm);

return array_ops.identity(result);
}
@@ -37,7 +37,7 @@ namespace Tensorflow
});
}

public static Tensor _clip(RefVariable @params, Tensor ids, string max_norm = null)
public static Tensor _clip(Tensor @params, Tensor ids, string max_norm = null)
{
if (max_norm == null)
return @params;


+ 8
- 0
src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs View File

@@ -22,5 +22,13 @@ namespace Tensorflow
else
variable_scopes_count[scope_name] = 1;
}

public int variable_scope_count(string scope_name)
{
if (variable_scopes_count.ContainsKey(scope_name))
return variable_scopes_count[scope_name];
else
return 0;
}
}
}

+ 36
- 5
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -5,6 +5,9 @@ using System.Text;

namespace Tensorflow
{
/// <summary>
/// A context manager for defining ops that creates variables (layers).
/// </summary>
public class variable_scope : IPython
{
public static string _VARSTORE_KEY = "__variable_store";
@@ -20,17 +23,19 @@ namespace Tensorflow
private ops.name_scope _current_name_scope;
private bool _auxiliary_name_scope;
private PureVariableScope _cached_pure_variable_scope;
private bool? _reuse;

public variable_scope(string name,
string default_name = "",
object values = null,
bool? reuse = null,
bool auxiliary_name_scope = true)
{
_name = name;
_default_name = default_name;
_values = values;
_current_name_scope = null;
_reuse = reuse;
_use_resource = false;
if (_default_name == null && _name == null)
throw new TypeError("If default_name is None then name is required");
@@ -41,13 +46,14 @@ namespace Tensorflow
public variable_scope(VariableScope scope,
string default_name = "",
object values = null,
bool? reuse = null,
bool auxiliary_name_scope = true)
{
_scope = scope;
_default_name = default_name;
_values = values;
_current_name_scope = null;
_reuse = reuse;
_use_resource = false;
if (_default_name == null && _scope == null)
throw new TypeError("If default_name is None then scope is required");
@@ -63,6 +69,9 @@ namespace Tensorflow
private VariableScope _enter_scope_uncached()
{
ops.name_scope current_name_scope;
PureVariableScope pure_variable_scope = null;
VariableScope entered_pure_variable_scope;

if (_auxiliary_name_scope)
// Create a new name scope later
current_name_scope = null;
@@ -85,18 +94,40 @@ namespace Tensorflow
var current_name_scope_name = current_name_scope;
_current_name_scope = current_name_scope;
string old_name_scope = current_name_scope_name;
PureVariableScope pure_variable_scope = null;
if(_scope == null)
pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope);
else
pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope);
pure_variable_scope.__enter__();
VariableScope entered_pure_variable_scope = pure_variable_scope;
entered_pure_variable_scope = pure_variable_scope;
_cached_pure_variable_scope = pure_variable_scope;
return entered_pure_variable_scope;
}
else
{
current_name_scope = new ops.name_scope(_default_name);
current_name_scope.__enter__();
string current_name_scope_name = current_name_scope;
_current_name_scope = current_name_scope;
string unique_default_name = _get_unique_variable_scope(_default_name);
pure_variable_scope = new PureVariableScope(unique_default_name,
old_name_scope: current_name_scope_name);
pure_variable_scope.__enter__();
entered_pure_variable_scope = pure_variable_scope;
_cached_pure_variable_scope = pure_variable_scope;
return entered_pure_variable_scope;
}
}

throw new NotImplementedException("_enter_scope_uncached");
public static string _get_unique_variable_scope(string prefix)
{
var var_scope_store = get_variable_scope_store();
var current_scope = get_variable_scope();
string name = !string.IsNullOrEmpty(current_scope._name) ? current_scope._name + "/" + prefix : prefix;
if (var_scope_store.variable_scope_count(name) == 0)
return prefix;
throw new NotImplementedException("_get_unique_variable_scope");
}

public static RefVariable default_variable_creator(object initial_value,


Loading…
Cancel
Save