Browse Source

Fix NeuralNetXor example.

tags/v0.100.4-load-saved-model
Haiping Chen 3 years ago
parent
commit
def066498d
4 changed files with 23 additions and 30 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  2. +1
    -1
      src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
  3. +17
    -24
      src/TensorFlowNET.Keras/Saving/hdf5_format.cs
  4. +3
    -3
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

+ 2
- 2
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -107,8 +107,8 @@ https://tensorflownet.readthedocs.io</Description>
</ItemGroup>

<ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.144" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Protobuf.Text" Version="0.5.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.0.1" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup>
</Project>

+ 1
- 1
src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs View File

@@ -48,7 +48,7 @@ namespace Tensorflow.Keras.Optimizers
public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true)
=> apply_gradients(grads_and_vars,
=> apply_gradients(new[] { grads_and_vars },
name: name,
experimental_aggregate_gradients: experimental_aggregate_gradients);



+ 17
- 24
src/TensorFlowNET.Keras/Saving/hdf5_format.cs View File

@@ -84,23 +84,18 @@ namespace Tensorflow.Keras.Saving
{
string original_keras_version = "2.5.0";
string original_backend = null;
if (Hdf5.AttributeExists(f, "keras_version"))
{
var (success, attr) = Hdf5.ReadStringAttributes(f, "keras_version", "");
if (success)
original_keras_version = attr.First();
// keras version should be 2.5.0+
var ver_major = int.Parse(original_keras_version.Split('.')[0]);
var ver_minor = int.Parse(original_keras_version.Split('.')[1]);
if (ver_major < 2 || (ver_major == 2 && ver_minor < 5))
throw new ValueError("keras version should be 2.5.0 or later.");
}
if (Hdf5.AttributeExists(f, "backend"))
{
var (success, attr) = Hdf5.ReadStringAttributes(f, "backend", "");
if (success)
original_backend = attr.First();
}
var (success, attr) = Hdf5.ReadStringAttributes(f, "keras_version", "", true);
if (success)
original_keras_version = attr.First();
// keras version should be 2.5.0+
var ver_major = int.Parse(original_keras_version.Split('.')[0]);
var ver_minor = int.Parse(original_keras_version.Split('.')[1]);
if (ver_major < 2 || (ver_major == 2 && ver_minor < 5))
throw new ValueError("keras version should be 2.5.0 or later.");
(success, attr) = Hdf5.ReadStringAttributes(f, "backend", "", true);
if (success)
original_backend = attr.First();

var filtered_layers = new List<ILayer>();
foreach (var layer in layers)
@@ -137,7 +132,7 @@ namespace Tensorflow.Keras.Saving
var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
foreach (var i_ in weight_names)
{
(bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
(success, Array result) = Hdf5.ReadDataset<float>(g, i_);
if (success)
weight_values.Add(np.array(result));
}
@@ -329,12 +324,10 @@ namespace Tensorflow.Keras.Saving

public static string[] load_attributes_from_hdf5_group(long group, string name)
{
if (Hdf5.AttributeExists(group, name))
{
var (success, attr) = Hdf5.ReadStringAttributes(group, name, "");
if (success)
return attr.ToArray();
}
var (success, attr) = Hdf5.ReadStringAttributes(group, name, "", true);
if (success)
return attr.ToArray();
return null;
}



+ 3
- 3
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -70,10 +70,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
</PropertyGroup>

<ItemGroup>
<PackageReference Include="HDF5-CSharp" Version="1.12.5" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.144" />
<PackageReference Include="HDF5-CSharp" Version="1.16.2" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.1" />
<PackageReference Include="SharpZipLib" Version="1.3.3" />
<PackageReference Include="SharpZipLib" Version="1.4.1" />
</ItemGroup>

<ItemGroup>


Loading…
Cancel
Save