Browse Source

Add set_weights and get_weights APIs

tags/v0.100.5-BERT-load
Wanglongzhi2001 2 years ago
parent
commit
426a55ce7b
3 changed files with 43 additions and 8 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  3. +41
    -6
      src/TensorFlowNET.Keras/Engine/Layer.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -19,7 +19,7 @@ namespace Tensorflow.Keras
List<IVariableV1> TrainableWeights { get; } List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; } List<IVariableV1> NonTrainableWeights { get; }
List<IVariableV1> Weights { get; set; } List<IVariableV1> Weights { get; set; }
void set_weights(List<NDArray> weights);
void set_weights(IEnumerable<NDArray> weights);
List<NDArray> get_weights(); List<NDArray> get_weights();
Shape OutputShape { get; } Shape OutputShape { get; }
Shape BatchInputShape { get; } Shape BatchInputShape { get; }


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -75,7 +75,7 @@ namespace Tensorflow
public List<IVariableV1> Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } public List<IVariableV1> Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }


public List<NDArray> get_weights() => throw new NotImplementedException(); public List<NDArray> get_weights() => throw new NotImplementedException();
public void set_weights(List<NDArray> weights) => throw new NotImplementedException();
public void set_weights(IEnumerable<NDArray> weights) => throw new NotImplementedException();
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException();


public Shape OutputShape => throw new NotImplementedException(); public Shape OutputShape => throw new NotImplementedException();


+ 41
- 6
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -30,6 +30,9 @@ using Tensorflow.Training;
using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Tensorflow.Framework;
using Tensorflow.Sessions;



namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
@@ -134,21 +137,53 @@ namespace Tensorflow.Keras.Engine
} }
} }


public virtual void set_weights(List<NDArray> weights)
public virtual void set_weights(IEnumerable<NDArray> weights)
{ {
if (Weights.Count() != weights.Count()) throw new ValueError( if (Weights.Count() != weights.Count()) throw new ValueError(
$"You called `set_weights` on layer \"{this.name}\"" + $"You called `set_weights` on layer \"{this.name}\"" +
$"with a weight list of length {len(weights)}, but the layer was " + $"with a weight list of length {len(weights)}, but the layer was " +
$"expecting {len(Weights)} weights."); $"expecting {len(Weights)} weights.");
for (int i = 0; i < weights.Count(); i++)



// check if the shapes are compatible
var weight_index = 0;
foreach(var w in weights)
{ {
if (weights[i].shape != Weights[i].shape)
if (!Weights[weight_index].AsTensor().is_compatible_with(w))
{ {
throw new ValueError($"Layer weight shape {weights[i].shape} not compatible with provided weight shape {Weights[i].shape}");
throw new ValueError($"Layer weight shape {w.shape} not compatible with provided weight shape {Weights[weight_index].shape}");
} }
weight_index++;
}

if (tf.executing_eagerly())
{
foreach (var (this_w, v_w) in zip(Weights, weights))
this_w.assign(v_w, read_value: true);
}
else
{
// TODO(Wanglongzhi2001):seems like there exist some bug in graph mode when define model, so uncomment the following when it fixed.

//Tensors assign_ops = new Tensors();
//var feed_dict = new FeedDict();

//Graph g = tf.Graph().as_default();
//foreach (var (this_w, v_w) in zip(Weights, weights))
//{
// var tf_dtype = this_w.dtype;
// var placeholder_shape = v_w.shape;
// var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
// var assign_op = this_w.assign(assign_placeholder);
// assign_ops.Add(assign_op);
// feed_dict.Add(assign_placeholder, v_w);
//}
//var sess = tf.Session().as_default();
//sess.run(assign_ops, feed_dict);

//g.Exit();
} }
foreach (var (this_w, v_w) in zip(Weights, weights))
this_w.assign(v_w, read_value: true);
} }


public List<NDArray> get_weights() public List<NDArray> get_weights()


Loading…
Cancel
Save