|
|
|
@@ -30,6 +30,9 @@ using Tensorflow.Training; |
|
|
|
using Tensorflow.Training.Saving.SavedModel; |
|
|
|
using Tensorflow.Util; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
using Tensorflow.Framework; |
|
|
|
using Tensorflow.Sessions; |
|
|
|
|
|
|
|
|
|
|
|
namespace Tensorflow.Keras.Engine |
|
|
|
{ |
|
|
|
@@ -134,6 +137,62 @@ namespace Tensorflow.Keras.Engine |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public virtual void set_weights(IEnumerable<NDArray> weights) |
|
|
|
{ |
|
|
|
if (Weights.Count() != weights.Count()) throw new ValueError( |
|
|
|
$"You called `set_weights` on layer \"{this.name}\"" + |
|
|
|
$"with a weight list of length {len(weights)}, but the layer was " + |
|
|
|
$"expecting {len(Weights)} weights."); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// check if the shapes are compatible |
|
|
|
var weight_index = 0; |
|
|
|
foreach(var w in weights) |
|
|
|
{ |
|
|
|
if (!Weights[weight_index].AsTensor().is_compatible_with(w)) |
|
|
|
{ |
|
|
|
throw new ValueError($"Layer weight shape {w.shape} not compatible with provided weight shape {Weights[weight_index].shape}"); |
|
|
|
} |
|
|
|
weight_index++; |
|
|
|
} |
|
|
|
|
|
|
|
if (tf.executing_eagerly()) |
|
|
|
{ |
|
|
|
foreach (var (this_w, v_w) in zip(Weights, weights)) |
|
|
|
this_w.assign(v_w, read_value: true); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
// TODO(Wanglongzhi2001):seems like there exist some bug in graph mode when define model, so uncomment the following when it fixed. |
|
|
|
|
|
|
|
//Tensors assign_ops = new Tensors(); |
|
|
|
//var feed_dict = new FeedDict(); |
|
|
|
|
|
|
|
//Graph g = tf.Graph().as_default(); |
|
|
|
//foreach (var (this_w, v_w) in zip(Weights, weights)) |
|
|
|
//{ |
|
|
|
// var tf_dtype = this_w.dtype; |
|
|
|
// var placeholder_shape = v_w.shape; |
|
|
|
// var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape); |
|
|
|
// var assign_op = this_w.assign(assign_placeholder); |
|
|
|
// assign_ops.Add(assign_op); |
|
|
|
// feed_dict.Add(assign_placeholder, v_w); |
|
|
|
//} |
|
|
|
//var sess = tf.Session().as_default(); |
|
|
|
//sess.run(assign_ops, feed_dict); |
|
|
|
|
|
|
|
//g.Exit(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public List<NDArray> get_weights() |
|
|
|
{ |
|
|
|
List<NDArray > weights = new List<NDArray>(); |
|
|
|
weights.AddRange(Weights.ConvertAll(x => x.numpy())); |
|
|
|
return weights; |
|
|
|
} |
|
|
|
|
|
|
|
protected int id; |
|
|
|
public int Id => id; |
|
|
|
protected string name; |
|
|
|
|