Browse Source

return weights for load_weights.

tags/yolov3
Oceania2018 4 years ago
parent
commit
48d96f4afc
9 changed files with 40 additions and 207 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
  3. +2
    -176
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  5. +8
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  6. +3
    -18
      src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
  7. +3
    -0
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  8. +8
    -7
      src/TensorFlowNET.Keras/Engine/Model.Training.cs
  9. +13
    -3
      src/TensorFlowNET.Keras/Saving/hdf5_format.cs

+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -7,7 +7,7 @@ namespace Tensorflow.Eager
{
public partial class EagerTensor
{
public EagerTensor(SafeTensorHandleHandle handle) : base(IntPtr.Zero)
public EagerTensor(SafeTensorHandleHandle handle)
{
_id = ops.uid();
EagerTensorHandle = handle;


+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs View File

@@ -63,7 +63,7 @@ namespace Tensorflow
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"AssignVariableOp", name,
null,
resource, value);


+ 2
- 176
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -265,182 +265,8 @@ namespace Tensorflow

private static unsafe NDArray fetchValue(IntPtr output)
{
NDArray ret;
using (var tensor = new Tensor(output))
{
var ndims = tensor.shape;
var srcAddress = c_api.TF_TensorData(output).ToInt64();

if (ndims.Length == 0)
{
switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
ret = NDArray.Scalar(*(bool*)srcAddress);
break;
case TF_DataType.TF_STRING:
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
ret = new NDArray(reader.ReadBytes().ToByteArray());
break;
case TF_DataType.TF_UINT8:
ret = NDArray.Scalar(*(byte*)srcAddress);
break;
case TF_DataType.TF_INT16:
ret = NDArray.Scalar(*(short*)srcAddress);
break;
case TF_DataType.TF_INT32:
ret = NDArray.Scalar(*(int*)srcAddress);
break;
case TF_DataType.TF_INT64:
ret = NDArray.Scalar(*(long*)srcAddress);
break;
case TF_DataType.TF_UINT16:
ret = NDArray.Scalar(*(ushort*)srcAddress);
break;
case TF_DataType.TF_UINT32:
ret = NDArray.Scalar(*(uint*)srcAddress);
break;
case TF_DataType.TF_UINT64:
ret = NDArray.Scalar(*(ulong*)srcAddress);
break;
case TF_DataType.TF_FLOAT:
ret = NDArray.Scalar(*(float*)srcAddress);
break;
case TF_DataType.TF_DOUBLE:
ret = NDArray.Scalar(*(double*)srcAddress);
break;
default:
throw new NotImplementedException("can't fetch output");
}
}
else
{
//var size = (long) tensor.size;
//var itemsize = (long) tensor.itemsize;
var bytesize = (long)tensor.bytesize;
var src = (void*)srcAddress;

#if _REGEN
#region Compute
switch (tensor.dtype)
{
%foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")%
case TF_DataType.#3:
{
ret = new NDArray(NPTypeCode.#1, ndims, false);
System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize);
break;
}
%
case TF_DataType.TF_STRING:
{
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
ret = NDArray.FromString(reader.ReadString());
break;
}
default:
throw new NotSupportedException();
}
#endregion
#else

#region Compute

switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
{
ret = new NDArray(NPTypeCode.Boolean, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_UINT8:
{
ret = new NDArray(NPTypeCode.Byte, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_INT16:
{
ret = new NDArray(NPTypeCode.Int16, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_UINT16:
{
ret = new NDArray(NPTypeCode.UInt16, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_INT32:
{
ret = new NDArray(NPTypeCode.Int32, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_UINT32:
{
ret = new NDArray(NPTypeCode.UInt32, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_INT64:
{
ret = new NDArray(NPTypeCode.Int64, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_UINT64:
{
ret = new NDArray(NPTypeCode.UInt64, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_DOUBLE:
{
ret = new NDArray(NPTypeCode.Double, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_FLOAT:
{
ret = new NDArray(NPTypeCode.Single, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}

case TF_DataType.TF_STRING:
{
throw new NotImplementedException();
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString
#pragma warning disable CS0162 // Unreachable code detected
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
#pragma warning restore CS0162 // Unreachable code detected
ret = NDArray.FromString(reader.ReadString());
break;
}

default:
throw new NotSupportedException();
}

#endregion

#endif
}
}

return ret;
var tensor = new Tensor(output);
return tensor.numpy();
}

/// <summary>


+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -82,7 +82,7 @@ TensorFlow .NET v0.3x is focused on making more Keras API works</PackageReleaseN
<ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" />
<PackageReference Include="NumSharp.Lite" Version="0.1.11" />
<PackageReference Include="NumSharp.Lite" Version="0.1.12" />
<PackageReference Include="Protobuf.Text" Version="0.4.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" />
</ItemGroup>


+ 8
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -48,6 +48,11 @@ namespace Tensorflow

public IntPtr TensorDataPointer => _handle == IntPtr.Zero ? IntPtr.Zero : TF_TensorData(_handle);

public Tensor()
{

}

/// <summary>
/// Create a Tensor object from an existing TF handle
/// </summary>
@@ -56,6 +61,9 @@ namespace Tensorflow
{
_handle = handle;
//no need to set AllocationType = AllocationType.None;
#if TRACK_TENSOR_LIFE
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}");
#endif
}

public Tensor(int value)


+ 3
- 18
src/TensorFlowNET.Core/Tensors/Tensor.Value.cs View File

@@ -163,6 +163,9 @@ namespace Tensorflow
break;
case TF_DataType.TF_STRING:
return np.array(StringBytes()[0]);
case TF_DataType.TF_UINT8:
storage = new UnmanagedStorage(NPTypeCode.Byte);
break;
case TF_DataType.TF_INT32:
storage = new UnmanagedStorage(NPTypeCode.Int32);
break;
@@ -186,23 +189,6 @@ namespace Tensorflow
return new NDArray(storage);
}

/*protected unsafe NDArray GetScalar(TF_DataType dtype)
{
switch(dtype)
{
case TF_DataType.TF_STRING:
return (NDArray)StringData()[0];
case TF_DataType.TF_INT32:
return *(int*)buffer;
case TF_DataType.TF_FLOAT:
return *(float*)buffer;
case TF_DataType.TF_DOUBLE:
return *(double*)buffer;
default:
return BufferToArray();
}
}*/

/// <summary>
/// Copies the memory of current buffer onto newly allocated array.
/// </summary>
@@ -210,7 +196,6 @@ namespace Tensorflow
public unsafe byte[] BufferToArray()
{
// ReSharper disable once LocalVariableHidesMember
var bytesize = (long)this.bytesize;
var data = new byte[bytesize];
fixed (byte* dst = data)
System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize);


+ 3
- 0
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -100,6 +100,9 @@ namespace Tensorflow
if (read_value)
return gen_resource_variable_ops.read_variable_op(handle, dtype);

if (assign_op == null)
return null;

return assign_op;
}



+ 8
- 7
src/TensorFlowNET.Keras/Engine/Model.Training.cs View File

@@ -3,13 +3,14 @@ using System.Collections.Generic;
using System.Text;
using HDF.PInvoke;
using HDF5CSharp;
using NumSharp;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Engine
{
public partial class Model
{
public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
public List<(IVariableV1, NDArray)> load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
{
long fileId = Hdf5.OpenFile(filepath, true);

@@ -17,20 +18,20 @@ namespace Tensorflow.Keras.Engine
bool lsuccess = Hdf5.GroupExists(fileId, "layer_names");

if (!lsuccess && msuccess)
{
fileId = H5G.open(fileId, "model_weights");
}
if (by_name)
{
//fdf5_format.load_weights_from_hdf5_group_by_name();
throw new NotImplementedException("");
}
else
{
hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
var weights = hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
Hdf5.CloseFile(fileId);
// return a reference to prevent GC collect Variable.
return weights;
}
Hdf5.CloseFile(fileId);
}

public void save_weights(string filepath, bool overwrite = true, string save_format = null, object options = null)
{
long fileId = Hdf5.CreateFile(filepath);


+ 13
- 3
src/TensorFlowNET.Keras/Saving/hdf5_format.cs View File

@@ -71,18 +71,20 @@ namespace Tensorflow.Keras.Saving
var target_class = layer.GetType().Name;
return weights;
}

public static void save_optimizer_weights_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
{

}

public static void load_optimizer_weights_from_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
{

}

public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
public static List<(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List<ILayer> layers)
{
string original_keras_version = "2.4.0";
string original_keras_version = "2.5.0";
string original_backend = null;
if (Hdf5.AttributeExists(f, "keras_version"))
{
@@ -156,15 +158,19 @@ namespace Tensorflow.Keras.Saving
}

keras.backend.batch_set_value(weight_value_tuples);
return weight_value_tuples;
}

public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
{

}

public static void load_weights_from_hdf5_group_by_name(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
{

}

public static void save_weights_to_hdf5_group(long f, List<ILayer> layers)
{
List<string> layerName=new List<string>();
@@ -260,8 +266,8 @@ namespace Tensorflow.Keras.Saving
WriteAttrs(f, getType,name, data);
}

}

private static void WriteDataset(long f, string name, Tensor data)
{
switch (data.dtype)
@@ -283,6 +289,7 @@ namespace Tensorflow.Keras.Saving
break;
}
}

private static void WriteAttrs(long f,string typename, string name, Array data)
{
switch (typename)
@@ -307,6 +314,7 @@ namespace Tensorflow.Keras.Saving
break;
}
}

private static List<List<object>> Split(Array list, int chunkSize)
{
var splitList = new List<List<object>>();
@@ -327,6 +335,7 @@ namespace Tensorflow.Keras.Saving

return splitList;
}

public static string[] load_attributes_from_hdf5_group(long group, string name)
{
if (Hdf5.AttributeExists(group, name))
@@ -337,6 +346,7 @@ namespace Tensorflow.Keras.Saving
}
return null;
}

public static void load_attributes_from_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
{



Loading…
Cancel
Save