Browse Source

update: keras.save_weights and keras.load_weights

tags/yolov3
dataangel Haiping 4 years ago
parent
commit
2f3a17d448
3 changed files with 181 additions and 11 deletions
  1. +8
    -2
      src/TensorFlowNET.Keras/Engine/Model.Training.cs
  2. +1
    -1
      src/TensorFlowNET.Keras/Losses/LogCosh.cs
  3. +172
    -8
      src/TensorFlowNET.Keras/Saving/hdf5_format.cs

+ 8
- 2
src/TensorFlowNET.Keras/Engine/Model.Training.cs View File

@@ -27,9 +27,15 @@ namespace Tensorflow.Keras.Engine
}
else
{
fdf5_format.load_weights_from_hdf5_group(fileId, Layers);
hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
}
H5G.close(fileId);
Hdf5.CloseFile(fileId);
}
/// <summary>
/// Saves the model's layer weights to an HDF5 file at <paramref name="filepath"/>.
/// </summary>
/// <param name="filepath">Destination path of the HDF5 weights file.</param>
/// <param name="overwrite">When false, refuses to replace an existing file.</param>
/// <param name="save_format">Currently unused — only the HDF5 format is supported here. TODO(review): honor "tf" once SavedModel export exists.</param>
/// <param name="options">Currently unused; kept for Keras API parity.</param>
public void save_weights(string filepath, bool overwrite = true, string save_format = null, object options = null)
{
    // Bug fix: 'overwrite' was previously ignored and the file was always clobbered.
    if (!overwrite && System.IO.File.Exists(filepath))
        throw new System.IO.IOException($"The file {filepath} already exists and overwrite is disabled.");

    long fileId = Hdf5.CreateFile(filepath);
    try
    {
        hdf5_format.save_weights_to_hdf5_group(fileId, Layers);
    }
    finally
    {
        // Always release the native HDF5 handle, even if serialization throws.
        Hdf5.CloseFile(fileId);
    }
}
}
}


+ 1
- 1
src/TensorFlowNET.Keras/Losses/LogCosh.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Losses
public LogCosh(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "huber" : name){ }
base(reduction: reduction, name: name == null ? "log_cosh" : name){ }

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{


src/TensorFlowNET.Keras/Saving/fdf5_format.cs → src/TensorFlowNET.Keras/Saving/hdf5_format.cs View File

@@ -8,12 +8,12 @@ using HDF5CSharp;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System.Linq;
using Tensorflow.Util;
namespace Tensorflow.Keras.Saving
{
public class fdf5_format
public class hdf5_format
{
private static int HDF5_OBJECT_HEADER_LIMIT = 64512;
public static void load_model_from_hdf5(string filepath = "", Dictionary<string, object> custom_objects = null, bool compile = false)
{
long root = Hdf5.OpenFile(filepath,true);
@@ -79,10 +79,7 @@ namespace Tensorflow.Keras.Saving
{

}
public static void save_weights_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
{

}
public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
{
string original_keras_version = "2.4.0";
@@ -136,9 +133,14 @@ namespace Tensorflow.Keras.Saving
var weight_values = new List<NDArray>();
long g = H5G.open(f, name);
var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
var get_Name = "";
foreach (var i_ in weight_names)
{
(bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
get_Name = i_;
if (get_Name.IndexOf("/") > 1) {
get_Name = get_Name.Split('/')[1];
}
(bool success, Array result) = Hdf5.ReadDataset<float>(g, get_Name);
if (success)
weight_values.Add(np.array(result));
}
@@ -165,9 +167,171 @@ namespace Tensorflow.Keras.Saving
{

}
public static void save_attributes_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
/// <summary>
/// Serializes the weights of <paramref name="layers"/> into an open HDF5 file/group,
/// mirroring Keras' save_weights_to_hdf5_group: writes the "layer_names", "backend"
/// and "keras_version" attributes on the root, then one sub-group per layer holding
/// that layer's weight datasets plus a "weight_names" attribute.
/// </summary>
/// <param name="f">Handle of an already-open HDF5 file or group to write into.</param>
/// <param name="layers">Layers whose weights (via _legacy_weights) are saved.</param>
public static void save_weights_to_hdf5_group(long f, List<ILayer> layers)
{
    // Root-level metadata: the ordered list of layer names plus backend/version tags,
    // matching what the Keras Python implementation writes.
    List<string> layerName=new List<string>();
    foreach (var layer in layers)
    {
        layerName.Add(layer.Name);
    }
    save_attributes_to_hdf5_group(f, "layer_names", layerName.ToArray());
    Hdf5.WriteAttribute(f, "backend", "tensorflow");
    Hdf5.WriteAttribute(f, "keras_version", "2.5.0");

    long g = 0, crDataGroup=0;
    List<IVariableV1> weights = new List<IVariableV1>();
    //List<IVariableV1> weight_values = new List<IVariableV1>();
    List<string> weight_names = new List<string>();
    foreach (var layer in layers) {
        // One HDF5 sub-group per layer, named after the (normalized) layer name.
        weight_names = new List<string>();
        g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name));
        weights = _legacy_weights(layer);
        //weight_values= keras.backend.batch_get_value(weights);
        foreach (var weight in weights)
        {
            weight_names.Add(weight.Name);
        }
        save_attributes_to_hdf5_group(g, "weight_names", weight_names.ToArray());
        Tensor tensor = null;
        string get_Name = "";
        foreach (var (name, val) in zip(weight_names, weights)) {
            get_Name = name;
            tensor = val.AsTensor();
            // NOTE(review): IndexOf("/") > 1 skips names whose first '/' sits at index 0 or 1,
            // and Split('/')[1] keeps only the second path segment of names like
            // "dense/kernel:0" — presumably to strip the layer prefix. The loader
            // (load_weights_from_hdf5_group) uses the same convention; confirm they stay in sync.
            if (get_Name.IndexOf("/") > 1)
            {
                get_Name = name.Split('/')[1];
                // Creates an (empty) sub-group named after the weight, then still writes the
                // dataset directly under the layer group below — TODO(review): confirm this
                // is the layout the loader expects rather than a dataset inside the sub-group.
                crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(get_Name));
                Hdf5.CloseGroup(crDataGroup);
            }
            WriteDataset(g, get_Name, tensor);
            tensor = null;
        }
        Hdf5.CloseGroup(g);
        weight_names = null;
    }
    weights = null;
    // weight_values = null;

}
/// <summary>
/// Stores <paramref name="data"/> as attribute <paramref name="name"/> on HDF5 object
/// <paramref name="f"/>. Mirrors Keras' save_attributes_to_hdf5_group: when the raw
/// payload would exceed the HDF5 object-header limit (HDF5_OBJECT_HEADER_LIMIT bytes),
/// the array is split into numbered attributes ("name0", "name1", ...) so each chunk
/// stays under the limit.
/// </summary>
/// <param name="f">Handle of the HDF5 object the attribute is attached to.</param>
/// <param name="name">Attribute name (suffixed with a chunk index when split).</param>
/// <param name="data">Attribute payload; element type decides the HDF5 type.</param>
private static void save_attributes_to_hdf5_group(long f, string name, Array data)
{
    // The runtime element type drives both the HDF5 attribute type (see WriteAttrs)
    // and the per-element byte-size estimate. Empty arrays are written as strings.
    string getType = data.Length > 0 ? data.GetValue(0).GetType().Name.ToLower() : "string";

    // -1 means "unknown / variable size" (e.g. strings), which disables chunking.
    int getSize;
    switch (getType)
    {
        case "single":
            getSize = sizeof(float);
            break;
        case "double":
            getSize = sizeof(double);
            break;
        case "int32":
            getSize = sizeof(int);
            break;
        case "int64":
            getSize = sizeof(long);
            break;
        case "string":
        default:
            getSize = -1;
            break;
    }

    int num_chunks = 1;
    if (getSize != -1)
    {
        // Bug fix: the element count is data.Length — the old code measured it as the
        // Count of a throwaway Split(data, 1) call, which only worked by accident.
        num_chunks = (int)Math.Ceiling((double)(data.Length * getSize) / (double)HDF5_OBJECT_HEADER_LIMIT);
    }

    if (num_chunks > 1)
    {
        // Bug fix: Split takes a chunk SIZE, not a chunk count — convert the desired
        // number of chunks into a per-chunk element count before splitting, so exactly
        // num_chunks attributes are written (matching what the Keras loader expects).
        int chunkSize = (int)Math.Ceiling((double)data.Length / (double)num_chunks);
        var chunked_data = Split(data, chunkSize);
        foreach (var (chunk_id, chunk_data) in enumerate(chunked_data))
        {
            WriteAttrs(f, getType, $"{name}{chunk_id}", chunk_data.ToArray());
        }
    }
    else
    {
        WriteAttrs(f, getType, name, data);
    }
}
/// <summary>
/// Writes a tensor as an HDF5 dataset named <paramref name="name"/> under group
/// <paramref name="f"/>, dispatching on the tensor's dtype so the stored element
/// type matches the tensor's element type.
/// </summary>
/// <param name="f">Handle of the HDF5 group to write into.</param>
/// <param name="name">Dataset name within the group.</param>
/// <param name="data">Tensor whose values are materialized via numpy() and stored.</param>
private static void WriteDataset(long f, string name, Tensor data)
{
    switch (data.dtype)
    {
        case TF_DataType.TF_FLOAT:
            Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
            break;
        case TF_DataType.TF_DOUBLE:
            // Bug fix: previously converted through ToMuliDimArray<float>(), silently
            // truncating double precision and mismatching the declared dataset type.
            Hdf5.WriteDatasetFromArray<double>(f, name, data.numpy().ToMuliDimArray<double>());
            break;
        case TF_DataType.TF_INT32:
            // Bug fix: was ToMuliDimArray<float>() for an int dataset.
            Hdf5.WriteDatasetFromArray<int>(f, name, data.numpy().ToMuliDimArray<int>());
            break;
        case TF_DataType.TF_INT64:
            // Bug fix: was ToMuliDimArray<float>() for a long dataset.
            Hdf5.WriteDatasetFromArray<long>(f, name, data.numpy().ToMuliDimArray<long>());
            break;
        default:
            // Fallback: store unrecognized dtypes as float, as before.
            Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
            break;
    }
}
/// <summary>
/// Writes <paramref name="data"/> as HDF5 attribute <paramref name="name"/> on object
/// <paramref name="f"/>, selecting the generic attribute type from the lower-cased
/// element type name. Unrecognized type names fall back to string attributes.
/// </summary>
/// <param name="f">Handle of the HDF5 object the attribute is attached to.</param>
/// <param name="typename">Lower-cased CLR type name ("single", "double", "int32", "int64", "string").</param>
/// <param name="name">Attribute name.</param>
/// <param name="data">Attribute payload.</param>
private static void WriteAttrs(long f, string typename, string name, Array data)
{
    if (typename == "single")
        Hdf5.WriteAttributes<float>(f, name, data);
    else if (typename == "double")
        Hdf5.WriteAttributes<double>(f, name, data);
    else if (typename == "int32")
        Hdf5.WriteAttributes<int>(f, name, data);
    else if (typename == "int64")
        Hdf5.WriteAttributes<long>(f, name, data);
    else
        // "string" and any unknown type name are both written as strings,
        // matching the original switch's "string" case and default branch.
        Hdf5.WriteAttributes<string>(f, name, data);
}
/// <summary>
/// Partitions <paramref name="list"/> into consecutive chunks of at most
/// <paramref name="chunkSize"/> elements (the final chunk may be shorter).
/// An empty input yields an empty list of chunks.
/// </summary>
/// <param name="list">Source array; elements are boxed via GetValue.</param>
/// <param name="chunkSize">Maximum number of elements per chunk (must be positive).</param>
/// <returns>The chunks, in source order.</returns>
private static List<List<object>> Split(Array list, int chunkSize)
{
    var chunks = new List<List<object>>();
    int total = list.Length;

    // Step through the array one chunk at a time instead of precomputing the
    // chunk count; the bounds check below clips the trailing partial chunk.
    for (int start = 0; start < total; start += chunkSize)
    {
        var chunk = new List<object>(chunkSize);
        for (int idx = start; idx < start + chunkSize && idx < total; idx++)
            chunk.Add(list.GetValue(idx));
        chunks.Add(chunk);
    }

    return chunks;
}
public static string[] load_attributes_from_hdf5_group(long group, string name)
{

Loading…
Cancel
Save