From 35e070db8f4fd07a261043c5f16c340423491c3b Mon Sep 17 00:00:00 2001 From: dataangel Date: Sun, 17 Jan 2021 13:55:04 +0800 Subject: [PATCH] Update hdf5_format.cs --- src/TensorFlowNET.Keras/Saving/hdf5_format.cs | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs index c7abc046..f0a3d4b1 100644 --- a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs +++ b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Text; using HDF.PInvoke; @@ -136,11 +136,7 @@ namespace Tensorflow.Keras.Saving var get_Name = ""; foreach (var i_ in weight_names) { - get_Name = i_; - if (get_Name.IndexOf("/") > 1) { - get_Name = get_Name.Split('/')[1]; - } - (bool success, Array result) = Hdf5.ReadDataset(g, get_Name, alternativeName: i_); + (bool success, Array result) = Hdf5.ReadDataset(g, i_); if (success) weight_values.Add(np.array(result)); } @@ -193,17 +189,19 @@ namespace Tensorflow.Keras.Saving } save_attributes_to_hdf5_group(g, "weight_names", weight_names.ToArray()); Tensor tensor = null; - string get_Name = ""; foreach (var (name, val) in zip(weight_names, weights)) { - get_Name = name; + tensor = val.AsTensor(); - if (get_Name.IndexOf("/") > 1) + if (name.IndexOf("/") > 1) { - get_Name = name.Split('/')[1]; - crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(get_Name)); + crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(name.Split('/')[0])); + WriteDataset(crDataGroup, name.Split('/')[1], tensor); Hdf5.CloseGroup(crDataGroup); } - WriteDataset(g, get_Name, tensor); + else { + WriteDataset(crDataGroup, name, tensor); + } + tensor = null; } Hdf5.CloseGroup(g);