Browse Source

change RefVariable to VariableV1

tags/v0.12
Oceania2018 6 years ago
parent
commit
4f366ca18a
21 changed files with 269 additions and 51 deletions
  1. +16
    -7
      src/TensorFlowNET.Core/APIs/keras.layers.cs
  2. +23
    -1
      src/TensorFlowNET.Core/Keras/Engine/Model.cs
  3. +35
    -13
      src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
  4. +4
    -4
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  5. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/Conv.cs
  6. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  7. +15
    -2
      src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
  8. +34
    -2
      src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs
  9. +13
    -6
      src/TensorFlowNET.Core/Keras/Layers/Layer.cs
  10. +10
    -0
      src/TensorFlowNET.Core/Keras/Optimizers/IOptimizer.cs
  11. +14
    -0
      src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs
  12. +14
    -0
      src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs
  13. +2
    -2
      src/TensorFlowNET.Core/Keras/backend.cs
  14. +2
    -2
      src/TensorFlowNET.Core/Layers/Layer.cs
  15. +1
    -1
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  16. +45
    -0
      src/TensorFlowNET.Core/Operations/embedding_ops.cs
  17. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  18. +3
    -3
      src/TensorFlowNET.Core/Train/Trackable.cs
  19. +2
    -2
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  20. +2
    -1
      src/TensorFlowNET.Core/Variables/VariableV1.cs
  21. +29
    -0
      test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs

+ 16
- 7
src/TensorFlowNET.Core/APIs/keras.layers.cs View File

@@ -38,13 +38,22 @@ namespace Tensorflow
var batch_size = batch_shape[0]; var batch_size = batch_shape[0];
var shape = batch_shape.Skip(1).ToArray(); var shape = batch_shape.Skip(1).ToArray();


var input_layer = new InputLayer(
input_shape: shape,
batch_size: batch_size,
name: name,
dtype: dtype,
sparse: sparse,
input_tensor: tensor);
InputLayer input_layer = null;
if (batch_shape != null)
input_layer = new InputLayer(
batch_input_shape: batch_shape,
name: name,
dtype: dtype,
sparse: sparse,
input_tensor: tensor);
else
input_layer = new InputLayer(
input_shape: shape,
batch_size: batch_size,
name: name,
dtype: dtype,
sparse: sparse,
input_tensor: tensor);


var outputs = input_layer.inbound_nodes[0].output_tensors; var outputs = input_layer.inbound_nodes[0].output_tensors;




+ 23
- 1
src/TensorFlowNET.Core/Keras/Engine/Model.cs View File

@@ -1,11 +1,33 @@
namespace Tensorflow.Keras.Engine
using Tensorflow.Keras.Optimizers;

namespace Tensorflow.Keras.Engine
{ {
public class Model : Network public class Model : Network
{ {
bool _cloning;
bool _is_compiled;
string loss;
IOptimizer optimizer;

public Model(string name = null) public Model(string name = null)
: base(name: name) : base(name: name)
{ {


} }

public void compile(string optimizerName, string lossName)
{
switch (optimizerName)
{
case "rmsprop":
optimizer = new RMSprop();
break;
}

loss = lossName;
_is_compiled = true;

// Prepare list of loss functions, same size of model outputs.
}
} }
} }

+ 35
- 13
src/TensorFlowNET.Core/Keras/Engine/Sequential.cs View File

@@ -20,6 +20,9 @@ namespace Tensorflow.Keras.Engine
{ {
public class Sequential : Model, IObjectLife public class Sequential : Model, IObjectLife
{ {
bool _is_graph_network;
Tensor[] outputs;

public Sequential(string name = null) public Sequential(string name = null)
: base(name: name) : base(name: name)
{ {
@@ -42,21 +45,40 @@ namespace Tensorflow.Keras.Engine
var set_inputs = false; var set_inputs = false;
if(_layers.Count == 0) if(_layers.Count == 0)
{ {
var (batch_shape, dtype) = (layer._batch_input_shape, layer._dtype);
if(batch_shape != null)
if(layer is InputLayer)
{ {
// Instantiate an input layer.
var x = keras.layers.Input(
batch_shape: batch_shape,
dtype: dtype,
name: layer.name + "_input");

// This will build the current layer
// and create the node connecting the current layer
// to the input layer we just created.
layer.__call__(x);
set_inputs = true;

} }
else
{
var (batch_shape, dtype) = (layer._batch_input_shape, layer._dtype);
if (batch_shape != null)
{
// Instantiate an input layer.
var x = keras.layers.Input(
batch_shape: batch_shape,
dtype: dtype,
name: layer.name + "_input");

// This will build the current layer
// and create the node connecting the current layer
// to the input layer we just created.
layer.__call__(x);
set_inputs = true;
}
}

if (set_inputs)
{
// If an input layer (placeholder) is available.
// outputs = layer._inbound_nodes;
}

}

if (set_inputs || _is_graph_network)
{

} }
} }




+ 4
- 4
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers
var param_shape = new int[] { input_shape.dims[axis[0]] }; var param_shape = new int[] { input_shape.dims[axis[0]] };


if (scale) if (scale)
gamma = add_weight("gamma",
gamma = (RefVariable)add_weight("gamma",
param_shape, param_shape,
dtype: param_dtype, dtype: param_dtype,
initializer: gamma_initializer, initializer: gamma_initializer,
@@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Layers
throw new NotImplementedException("add_weight gamma"); throw new NotImplementedException("add_weight gamma");


if (center) if (center)
beta = add_weight("beta",
beta = (RefVariable)add_weight("beta",
param_shape, param_shape,
dtype: param_dtype, dtype: param_dtype,
initializer: beta_initializer, initializer: beta_initializer,
@@ -117,7 +117,7 @@ namespace Tensorflow.Keras.Layers
} }


moving_mean = add_weight("moving_mean",
moving_mean = (RefVariable)add_weight("moving_mean",
param_shape, param_shape,
dtype: param_dtype, dtype: param_dtype,
initializer: moving_mean_initializer, initializer: moving_mean_initializer,
@@ -125,7 +125,7 @@ namespace Tensorflow.Keras.Layers
trainable: false, trainable: false,
aggregation: VariableAggregation.Mean); aggregation: VariableAggregation.Mean);


moving_variance = add_weight("moving_variance",
moving_variance = (RefVariable)add_weight("moving_variance",
shape: param_shape, shape: param_shape,
dtype: param_dtype, dtype: param_dtype,
initializer: moving_variance_initializer, initializer: moving_variance_initializer,


+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/Conv.cs View File

@@ -75,13 +75,13 @@ namespace Tensorflow.Keras.Layers
input_shape.dims[input_shape.ndim + channel_axis] : input_shape.dims[input_shape.ndim + channel_axis] :
input_shape.dims[channel_axis]; input_shape.dims[channel_axis];
var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters };
kernel = add_weight(name: "kernel",
kernel = (RefVariable)add_weight(name: "kernel",
shape: kernel_shape, shape: kernel_shape,
initializer: kernel_initializer, initializer: kernel_initializer,
trainable: true, trainable: true,
dtype: _dtype); dtype: _dtype);
if (use_bias) if (use_bias)
bias = add_weight(name: "bias",
bias = (RefVariable)add_weight(name: "bias",
shape: new int[] { filters }, shape: new int[] { filters },
initializer: bias_initializer, initializer: bias_initializer,
trainable: true, trainable: true,


+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -55,14 +55,14 @@ namespace Tensorflow.Keras.Layers
var axes = new Dictionary<int, int>(); var axes = new Dictionary<int, int>();
axes[-1] = last_dim; axes[-1] = last_dim;
input_spec = new InputSpec(min_ndim: 2, axes: axes); input_spec = new InputSpec(min_ndim: 2, axes: axes);
kernel = add_weight(
kernel = (RefVariable)add_weight(
"kernel", "kernel",
shape: new int[] { last_dim, units }, shape: new int[] { last_dim, units },
initializer: kernel_initializer, initializer: kernel_initializer,
dtype: _dtype, dtype: _dtype,
trainable: true); trainable: true);
if (use_bias) if (use_bias)
bias = add_weight(
bias = (RefVariable)add_weight(
"bias", "bias",
shape: new int[] { units }, shape: new int[] { units },
initializer: bias_initializer, initializer: bias_initializer,


+ 15
- 2
src/TensorFlowNET.Core/Keras/Layers/Embedding.cs View File

@@ -23,20 +23,23 @@ namespace Tensorflow.Keras.Layers
private int input_dim; private int input_dim;
private int output_dim; private int output_dim;
private bool mask_zero; private bool mask_zero;
public RefVariable embeddings;
public VariableV1 embeddings;
public IInitializer embeddings_initializer; public IInitializer embeddings_initializer;
int input_length;


public Embedding(int input_dim, int output_dim, public Embedding(int input_dim, int output_dim,
IInitializer embeddings_initializer = null, IInitializer embeddings_initializer = null,
bool mask_zero = false, bool mask_zero = false,
TF_DataType dtype = TF_DataType.TF_FLOAT, TF_DataType dtype = TF_DataType.TF_FLOAT,
int[] input_shape = null) : base(dtype: dtype, input_shape: input_shape)
int[] input_shape = null,
int input_length = -1) : base(dtype: dtype, input_shape: input_shape ?? new[] { input_length })
{ {
this.input_dim = input_dim; this.input_dim = input_dim;
this.output_dim = output_dim; this.output_dim = output_dim;
this.embeddings_initializer = embeddings_initializer == null ? tf.uniform_initializer : embeddings_initializer; this.embeddings_initializer = embeddings_initializer == null ? tf.uniform_initializer : embeddings_initializer;
this.mask_zero = mask_zero; this.mask_zero = mask_zero;
supports_masking = mask_zero; supports_masking = mask_zero;
this.input_length = input_length;
} }


protected override void build(TensorShape input_shape) protected override void build(TensorShape input_shape)
@@ -46,5 +49,15 @@ namespace Tensorflow.Keras.Layers
name: "embeddings"); name: "embeddings");
built = true; built = true;
} }

protected override Tensor call(Tensor inputs, Tensor training = null)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)
inputs = math_ops.cast(inputs, tf.int32);

var @out = embedding_ops.embedding_lookup(embeddings, inputs);
return @out;
}
} }
} }

+ 34
- 2
src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs View File

@@ -15,6 +15,8 @@
******************************************************************************/ ******************************************************************************/


using System; using System;
using System.Collections.Generic;
using System.Linq;


namespace Tensorflow.Keras.Layers namespace Tensorflow.Keras.Layers
{ {
@@ -28,21 +30,47 @@ namespace Tensorflow.Keras.Layers
public bool is_placeholder; public bool is_placeholder;


public InputLayer(int[] input_shape = null, public InputLayer(int[] input_shape = null,
int[] batch_input_shape = null,
int? batch_size = null, int? batch_size = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
string name = null, string name = null,
bool sparse = false, bool sparse = false,
Tensor input_tensor = null)
Tensor input_tensor = null) : base(dtype: dtype, name: name)
{ {
built = true; built = true;
this.sparse = sparse; this.sparse = sparse;
this.batch_size = batch_size; this.batch_size = batch_size;
this.supports_masking = true; this.supports_masking = true;


if(batch_input_shape != null)
{
batch_size = batch_input_shape[0];
input_shape = batch_input_shape.Skip(1).ToArray();
}

// moved to base class
if (string.IsNullOrEmpty(name))
{
var prefix = "input";
name = prefix + '_' + backend.get_uid(prefix);
}

if (input_tensor == null) if (input_tensor == null)
{ {
var batch_input_shape = new int[] { batch_size.HasValue ? batch_size.Value : -1, -1 };
if(input_shape != null)
{
var dims = new List<int> { batch_size.HasValue ? batch_size.Value : -1 };
dims.AddRange(input_shape);
batch_input_shape = dims.ToArray();
}
else
{
batch_input_shape = null;
}

var graph = backend.get_graph().as_default();


// In graph mode, create a graph placeholder to call the layer on.
if (sparse) if (sparse)
{ {
throw new NotImplementedException("InputLayer sparse is true"); throw new NotImplementedException("InputLayer sparse is true");
@@ -59,6 +87,10 @@ namespace Tensorflow.Keras.Layers
_batch_input_shape = batch_input_shape; _batch_input_shape = batch_input_shape;
} }


// Create an input node to add to self.outbound_node
// and set output_tensors' _keras_history.
// input_tensor._keras_history = base_layer.KerasHistory(self, 0, 0)
// input_tensor._keras_mask = None
new Node(this, new Node(this,
inbound_layers: new Layer[0], inbound_layers: new Layer[0],
node_indices: new int[0], node_indices: new int[0],


+ 13
- 6
src/TensorFlowNET.Core/Keras/Layers/Layer.cs View File

@@ -51,7 +51,7 @@ namespace Tensorflow.Keras.Layers
/// </summary> /// </summary>
protected InputSpec input_spec; protected InputSpec input_spec;
protected bool supports_masking; protected bool supports_masking;
protected List<RefVariable> _trainable_weights;
protected List<VariableV1> _trainable_weights;
private string _name; private string _name;
public string name => _name; public string name => _name;
protected string _base_name; protected string _base_name;
@@ -65,6 +65,8 @@ namespace Tensorflow.Keras.Layers
private List<Node> _outbound_nodes; private List<Node> _outbound_nodes;
public List<Node> outbound_nodes => _outbound_nodes; public List<Node> outbound_nodes => _outbound_nodes;


float _initial_weights;

public Layer(bool trainable = true, public Layer(bool trainable = true,
string name = null, string name = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
@@ -81,13 +83,18 @@ namespace Tensorflow.Keras.Layers
this.supports_masking = false; this.supports_masking = false;


_init_set_name(name); _init_set_name(name);
_trainable_weights = new List<RefVariable>();
_trainable_weights = new List<VariableV1>();
_compute_previous_mask = false; _compute_previous_mask = false;
_updates = new List<Operation>(); _updates = new List<Operation>();


// Manage input shape information if passed. // Manage input shape information if passed.

_batch_input_shape = new int[] { -1, -1 };
if(input_shape != null)
{
var shapes = new List<int> { -1 };
shapes.AddRange(input_shape);
_batch_input_shape = shapes.ToArray();
}


_dtype = dtype; _dtype = dtype;


@@ -186,12 +193,12 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected virtual RefVariable add_weight(string name,
protected virtual VariableV1 add_weight(string name,
int[] shape, int[] shape,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null, IInitializer initializer = null,
bool? trainable = null, bool? trainable = null,
Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null)
Func<string, int[], TF_DataType, IInitializer, bool, VariableV1> getter = null)
{ {
if (dtype == TF_DataType.DtInvalid) if (dtype == TF_DataType.DtInvalid)
dtype = TF_DataType.TF_FLOAT; dtype = TF_DataType.TF_FLOAT;


+ 10
- 0
src/TensorFlowNET.Core/Keras/Optimizers/IOptimizer.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Optimizers
{
public interface IOptimizer
{
}
}

+ 14
- 0
src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Keras.Optimizers
{
/// <summary>
/// Updated base class for optimizers.
/// </summary>
public class OptimizerV2 : Trackable, IOptimizer
{
}
}

+ 14
- 0
src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Optimizers
{
/// <summary>
/// Optimizer that implements the RMSprop algorithm.
/// </summary>
public class RMSprop : OptimizerV2
{

}
}

+ 2
- 2
src/TensorFlowNET.Core/Keras/backend.cs View File

@@ -42,12 +42,12 @@ namespace Tensorflow.Keras
/// Allows to give unique autogenerated names to layers, in a graph-specific way. /// Allows to give unique autogenerated names to layers, in a graph-specific way.
/// </summary> /// </summary>
public static Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>(); public static Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();
public static Dictionary<string, RefVariable> _GRAPH_VARIABLES = new Dictionary<string, RefVariable>();
public static Dictionary<string, VariableV1> _GRAPH_VARIABLES = new Dictionary<string, VariableV1>();
public static Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>(); public static Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>();


public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph(); public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph();


public static void track_variable(RefVariable v)
public static void track_variable(VariableV1 v)
{ {
var graph = v.graph; var graph = v.graph;
_GRAPH_VARIABLES[graph.graph_key] = v; _GRAPH_VARIABLES[graph.graph_key] = v;


+ 2
- 2
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Layers
this._reuse = _reuse; this._reuse = _reuse;


// Avoid an incorrect lint error // Avoid an incorrect lint error
_trainable_weights = new List<RefVariable>();
_trainable_weights = new List<VariableV1>();
this.built = false; this.built = false;
_keras_style = false; _keras_style = false;
} }
@@ -109,7 +109,7 @@ namespace Tensorflow.Layers
/// <param name="synchronization"></param> /// <param name="synchronization"></param>
/// <param name="aggregation"></param> /// <param name="aggregation"></param>
/// <returns></returns> /// <returns></returns>
protected virtual RefVariable add_weight(string name,
protected virtual VariableV1 add_weight(string name,
int[] shape, int[] shape,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null, IInitializer initializer = null,


+ 1
- 1
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -600,7 +600,7 @@ namespace Tensorflow
return gen_array_ops.concat_v2(values, axis, name: name); return gen_array_ops.concat_v2(values, axis, name: name);
} }
public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
public static Tensor gather<T1, T2>(T1 @params, T2 indices, string name = null, int axis = 0)
=> gen_array_ops.gather_v2(@params, indices, axis, name: name); => gen_array_ops.gather_v2(@params, indices, axis, name: name);
public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false) public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false)


+ 45
- 0
src/TensorFlowNET.Core/Operations/embedding_ops.cs View File

@@ -52,6 +52,38 @@ namespace Tensorflow
}); });
} }


/// <summary>
/// Helper function for embedding_lookup and _compute_sampled_logits.
/// </summary>
/// <param name="params"></param>
/// <param name="ids"></param>
/// <param name="partition_strategy"></param>
/// <param name="name"></param>
/// <param name="max_norm"></param>
/// <returns></returns>
public static Tensor _embedding_lookup_and_transform(VariableV1 @params,
Tensor ids,
string partition_strategy = "mod",
string name = null,
string max_norm = null)
{
return tf_with(ops.name_scope(name, "embedding_lookup", new { @params, ids }), scope =>
{
name = scope;
int np = 1;
ids = ops.convert_to_tensor(ids, name: "ids");
if (np == 1)
{
var gather = array_ops.gather(@params, ids, name: name);
var result = _clip(gather, ids, max_norm);

return array_ops.identity(result);
}

throw new NotImplementedException("_embedding_lookup_and_transform");
});
}

public static Tensor _embedding_lookup_and_transform(Tensor[] @params, public static Tensor _embedding_lookup_and_transform(Tensor[] @params,
Tensor ids, Tensor ids,
string partition_strategy = "mod", string partition_strategy = "mod",
@@ -98,5 +130,18 @@ namespace Tensorflow
name: name, name: name,
max_norm: max_norm); max_norm: max_norm);
} }

public static Tensor embedding_lookup(VariableV1 @params, Tensor ids,
string partition_strategy = "mod",
string name = null,
bool validate_indices = true,
string max_norm = null)
{
return _embedding_lookup_and_transform(@params: @params,
ids: ids,
partition_strategy: partition_strategy,
name: name,
max_norm: max_norm);
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -106,7 +106,7 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }


public static Tensor gather_v2(Tensor @params, Tensor indices, int axis, string name = null)
public static Tensor gather_v2<T1, T2>(T1 @params, T2 indices, int axis, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("GatherV2", name: name, new { @params, indices, axis }); var _op = _op_def_lib._apply_op_helper("GatherV2", name: name, new { @params, indices, axis });




+ 3
- 3
src/TensorFlowNET.Core/Train/Trackable.cs View File

@@ -26,11 +26,11 @@ namespace Tensorflow.Train
/// Restore-on-create for a variable be saved with this `Checkpointable`. /// Restore-on-create for a variable be saved with this `Checkpointable`.
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
protected virtual RefVariable _add_variable_with_custom_getter(string name,
protected virtual VariableV1 _add_variable_with_custom_getter(string name,
int[] shape, int[] shape,
TF_DataType dtype = TF_DataType.TF_FLOAT, TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null, IInitializer initializer = null,
Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null,
Func<string, int[], TF_DataType, IInitializer, bool, VariableV1> getter = null,
bool overwrite = false, bool overwrite = false,
bool trainable = false) bool trainable = false)
{ {
@@ -59,7 +59,7 @@ namespace Tensorflow.Train
// TODO // TODO
} }


protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false)
protected VariableV1 _track_checkpointable(VariableV1 checkpointable, string name, bool overwrite = false)
{ {
return checkpointable; return checkpointable;
} }


+ 2
- 2
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -28,14 +28,14 @@ namespace Tensorflow
public Tensor _initial_value; public Tensor _initial_value;
public string _graph_key; public string _graph_key;
public bool _trainable; public bool _trainable;
public Tensor _variable;
public Tensor _snapshot; public Tensor _snapshot;
public bool _save_slice_info; public bool _save_slice_info;


private Operation _initializer_op; private Operation _initializer_op;
public override Operation initializer => _initializer_op; public override Operation initializer => _initializer_op;
public override Operation op => _variable.op; public override Operation op => _variable.op;
public Graph graph => _variable.graph;
public TF_DataType dtype => _variable.dtype; public TF_DataType dtype => _variable.dtype;
public TensorShape shape => tensor_util.to_shape(_variable.shape); public TensorShape shape => tensor_util.to_shape(_variable.shape);




+ 2
- 1
src/TensorFlowNET.Core/Variables/VariableV1.cs View File

@@ -34,7 +34,8 @@ namespace Tensorflow
public virtual Tensor graph_element { get; } public virtual Tensor graph_element { get; }
public virtual Operation op { get; } public virtual Operation op { get; }
public virtual Operation initializer { get; } public virtual Operation initializer { get; }

public Tensor _variable;
public Graph graph => _variable.graph;
public VariableV1(object initial_value = null, public VariableV1(object initial_value = null,
bool trainable = true, bool trainable = true,
List<string> collections = null, List<string> collections = null,


+ 29
- 0
test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs View File

@@ -0,0 +1,29 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using NumSharp;

namespace TensorFlowNET.UnitTest.Keras
{
[TestClass]
public class EmbeddingTest
{
[TestMethod]
public void Embedding()
{
var model = new Sequential();
model.add(new Embedding(1000, 64, input_length: 10));
// the model will take as input an integer matrix of size (batch,
// input_length).
// the largest integer (i.e. word index) in the input should be no larger
// than 999 (vocabulary size).
// now model.output_shape == (None, 10, 64), where None is the batch
// dimension.
var input_array = np.random.randint(1000, size: (32, 10));
model.compile("rmsprop", "mse");
}
}
}

Loading…
Cancel
Save