Browse Source

tf.data framework.

tags/v0.20
Oceania2018 5 years ago
parent
commit
3a220fa52d
32 changed files with 993 additions and 70 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/c_api.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  3. +41
    -0
      src/TensorFlowNET.Core/Data/BatchDataset.cs
  4. +18
    -0
      src/TensorFlowNET.Core/Data/DatasetSource.cs
  5. +54
    -0
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  6. +25
    -1
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  7. +29
    -0
      src/TensorFlowNET.Core/Data/PrefetchDataset.cs
  8. +24
    -0
      src/TensorFlowNET.Core/Data/RepeatDataset.cs
  9. +37
    -0
      src/TensorFlowNET.Core/Data/ShuffleDataset.cs
  10. +20
    -0
      src/TensorFlowNET.Core/Data/TakeDataset.cs
  11. +10
    -6
      src/TensorFlowNET.Core/Data/TensorSliceDataset.cs
  12. +21
    -0
      src/TensorFlowNET.Core/Data/UnaryDataset.cs
  13. +18
    -0
      src/TensorFlowNET.Core/Data/UnaryUnchangedStructureDataset.cs
  14. +1
    -0
      src/TensorFlowNET.Core/Data/Utils.cs
  15. +60
    -9
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  16. +45
    -0
      src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs
  17. +0
    -33
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  18. +11
    -1
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  19. +31
    -0
      src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs
  20. +31
    -0
      src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
  21. +13
    -0
      src/TensorFlowNET.Core/Framework/Models/TypeSpec.cs
  22. +17
    -0
      src/TensorFlowNET.Core/Framework/random_seed.py.cs
  23. +2
    -2
      src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs
  24. +178
    -0
      src/TensorFlowNET.Core/Operations/dataset_ops.cs
  25. +274
    -0
      src/TensorFlowNET.Core/Range.cs
  26. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  27. +4
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs
  28. +3
    -4
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  29. +9
    -1
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  30. +2
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  31. +0
    -11
      src/TensorFlowNET.Core/Util/IPointerInputs.cs
  32. +2
    -0
      src/TensorFlowNET.Core/tensorflow.memory.cs

+ 1
- 1
src/TensorFlowNET.Core/APIs/c_api.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow
/// </summary>
public partial class c_api
{
public const string TensorFlowLibName = "tensorflow";
public const string TensorFlowLibName = @"D:\SciSharp\tensorflow-google\bazel-bin\tensorflow\tensorflow.dll";

public static string StringPiece(IntPtr handle)
{


+ 11
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -265,6 +265,17 @@ namespace Tensorflow
yield return (i, values[i]);
}

public static IEnumerable<(int, T)> enumerate<T>(IEnumerable<T> values, int start = 0)
{
int i = 0;
foreach(var val in values)
{
if (i < start)
continue;
yield return (i, val);
}
}

[DebuggerStepThrough]
public static Dictionary<string, object> ConvertToDict(object dyn)
{


+ 41
- 0
src/TensorFlowNET.Core/Data/BatchDataset.cs View File

@@ -0,0 +1,41 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework.Models;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// A `Dataset` that batches contiguous elements from its input.
/// </summary>
public class BatchDataset : UnaryDataset
{
Tensor _batch_size;
Tensor _drop_remainder;

public BatchDataset(IDatasetV2 input_dataset, int batch_size, bool drop_remainder = false) :
base(input_dataset)
{
_input_dataset = input_dataset;
_batch_size = tf.convert_to_tensor(batch_size, dtype: TF_DataType.TF_INT64, name: "batch_size");
_drop_remainder = tf.convert_to_tensor(drop_remainder, dtype: TF_DataType.TF_BOOL, name: "drop_remainder");
if (drop_remainder)
{
throw new NotImplementedException("");
}
else
{
_structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray();
}

variant_tensor = ops.batch_dataset_v2(input_dataset.variant_tensor,
_batch_size,
_drop_remainder,
output_types,
output_shapes);
}
}
}

+ 18
- 0
src/TensorFlowNET.Core/Data/DatasetSource.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework.Models;

namespace Tensorflow
{
public class DatasetSource : DatasetV2
{
protected Tensor[] _tensors;

public DatasetSource()
{

}
}
}

+ 54
- 0
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -0,0 +1,54 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework.Models;

namespace Tensorflow
{
/// <summary>
/// Abstract class representing a dataset with no inputs.
/// </summary>
public class DatasetV2 : IDatasetV2
{
protected dataset_ops ops = new dataset_ops();
public Tensor variant_tensor { get; set; }

public TensorSpec[] _structure { get; set; }

public TensorShape[] output_shapes => _structure.Select(x => x.shape).ToArray();
public TF_DataType[] output_types => _structure.Select(x => x.dtype).ToArray();
public TensorSpec[] element_spec => _structure;

public IDatasetV2 take(int count = -1)
=> new TakeDataset(this, count: count);

public IDatasetV2 batch(int batch_size, bool drop_remainder = false)
=> new BatchDataset(this, batch_size, drop_remainder: drop_remainder);

public IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null)
=> new PrefetchDataset(this, buffer_size: buffer_size, slack_period: slack_period);

public IDatasetV2 repeat(int count = -1)
=> new RepeatDataset(this, count: count);

public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
=> new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);
public override string ToString()
=> $"{GetType().Name} shapes: ({_structure[0].shape}, {_structure[1].shape}), types: (tf.{_structure[0].dtype.as_numpy_name()}, tf.{_structure[1].dtype.as_numpy_name()})";

public IEnumerator<(Tensor, Tensor)> GetEnumerator()
{
throw new NotImplementedException();
}

IEnumerator IEnumerable.GetEnumerator()
{
return this.GetEnumerator();
}
}
}

+ 25
- 1
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -1,11 +1,35 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework.Models;

namespace Tensorflow
{
public interface IDatasetV2
public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
{
Tensor variant_tensor { get; set; }

TensorShape[] output_shapes { get; }

TF_DataType[] output_types { get; }

TensorSpec[] element_spec { get; }

TensorSpec[] _structure { get; set; }

/// <summary>
///
/// </summary>
/// <param name="count"></param>
/// <returns></returns>
IDatasetV2 repeat(int count = -1);

IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true);

IDatasetV2 batch(int batch_size, bool drop_remainder = false);

IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null);

IDatasetV2 take(int count);
}
}

+ 29
- 0
src/TensorFlowNET.Core/Data/PrefetchDataset.cs View File

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// Creates a `Dataset` that prefetches elements from this dataset.
/// </summary>
public class PrefetchDataset : UnaryUnchangedStructureDataset
{
Tensor _buffer_size;

public PrefetchDataset(IDatasetV2 input_dataset,
long buffer_size = -1,
int? slack_period = null) :
base(input_dataset)
{
_buffer_size = tf.convert_to_tensor(buffer_size, dtype: TF_DataType.TF_INT64, name: "buffer_size");

variant_tensor = ops.prefetch_dataset(input_dataset.variant_tensor,
_buffer_size,
input_dataset.output_types,
input_dataset.output_shapes,
slack_period: slack_period);
}
}
}

+ 24
- 0
src/TensorFlowNET.Core/Data/RepeatDataset.cs View File

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// A `Dataset` that repeats its input several times.
/// </summary>
public class RepeatDataset : UnaryUnchangedStructureDataset
{
Tensor _count;

public RepeatDataset(IDatasetV2 input_dataset, int count = -1) :
base(input_dataset)
{
_count = constant_op.constant(count, dtype: TF_DataType.TF_INT64, name: "count");
variant_tensor = ops.repeat_dataset(input_dataset.variant_tensor,
_count,
input_dataset.output_types,
input_dataset.output_shapes);
}
}
}

+ 37
- 0
src/TensorFlowNET.Core/Data/ShuffleDataset.cs View File

@@ -0,0 +1,37 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// Randomly shuffles the elements of this dataset.
/// </summary>
public class ShuffleDataset : UnaryUnchangedStructureDataset
{
Tensor _buffer_size;
Tensor _seed;
Tensor _seed2;
bool _reshuffle_each_iteration;

public ShuffleDataset(IDatasetV2 input_dataset,
long buffer_size,
int? seed = null,
bool reshuffle_each_iteration = true) :
base(input_dataset)
{
_buffer_size = tf.convert_to_tensor(buffer_size, dtype: TF_DataType.TF_INT64, name: "buffer_size");
(_seed, _seed2) = random_seed.get_seed_tensor(seed);
_reshuffle_each_iteration = reshuffle_each_iteration;
var seed_generator = ops.dummy_seed_generator();
if (tf.context.executing_eagerly())
variant_tensor = ops.shuffle_dataset_v3(input_dataset.variant_tensor, _buffer_size,
_seed, _seed2, seed_generator,
output_types, output_shapes,
reshuffle_each_iteration: _reshuffle_each_iteration);
else
throw new NotImplementedException("");
}
}
}

+ 20
- 0
src/TensorFlowNET.Core/Data/TakeDataset.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class TakeDataset : UnaryUnchangedStructureDataset
{
Tensor _count;

public TakeDataset(IDatasetV2 input_dataset, int count) :
base(input_dataset)
{
_count = tf.convert_to_tensor(count, dtype: dtypes.int64, name: "count");
variant_tensor = ops.take_dataset(input_dataset.variant_tensor, _count,
output_types, output_shapes);
}
}
}

+ 10
- 6
src/TensorFlowNET.Core/Data/TensorSliceDataset.cs View File

@@ -1,19 +1,23 @@
using NumSharp;
using NumSharp.Utilities;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework.Models;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class TensorSliceDataset : IDatasetV2
public class TensorSliceDataset : DatasetSource
{
NDArray features;
NDArray labels;

public TensorSliceDataset(NDArray features, NDArray labels)
{
this.features = features;
this.labels = labels;
_tensors = new[] { tf.convert_to_tensor(features), tf.convert_to_tensor(labels) };
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
_structure = batched_spec.Select(x => x._unbatch()).ToArray();
variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes);
}
}
}

+ 21
- 0
src/TensorFlowNET.Core/Data/UnaryDataset.cs View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework.Models;

namespace Tensorflow
{
/// <summary>
/// Abstract class representing a dataset with one input.
/// </summary>
public class UnaryDataset : DatasetV2
{
protected IDatasetV2 _input_dataset;

public UnaryDataset(IDatasetV2 input_dataset)
{
_input_dataset = input_dataset;
_structure = input_dataset._structure;
}
}
}

+ 18
- 0
src/TensorFlowNET.Core/Data/UnaryUnchangedStructureDataset.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Represents a unary dataset with the same input and output structure.
/// </summary>
public class UnaryUnchangedStructureDataset : UnaryDataset
{
public UnaryUnchangedStructureDataset(IDatasetV2 input_dataset) :
base(input_dataset)
{
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Data/Utils.cs View File

@@ -6,6 +6,7 @@ using System.Net;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Tensorflow.Framework.Models;

namespace Tensorflow
{


+ 60
- 9
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -7,6 +7,7 @@ using Google.Protobuf.WellKnownTypes;
using System.Threading;
using Tensorflow.Util;
using System.Runtime.InteropServices.ComTypes;
using System.Runtime.InteropServices;

namespace Tensorflow.Eager
{
@@ -73,10 +74,11 @@ namespace Tensorflow.Eager
// Add inferred attrs and inputs.
for (int i = 0; i < op_def.InputArg.Count; i++)
{
var input = args[kFastPathExecuteInputStartIndex + i];
var input_arg = op_def.InputArg[i];
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{
int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length;
int len = (input as object[]).Length;
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
if (op_exec_info.run_callbacks)
{
@@ -102,7 +104,31 @@ namespace Tensorflow.Eager
}
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
{
throw new NotImplementedException("");
var attr_name = input_arg.TypeListAttr;
var fast_input_array = input as object[];
var len = fast_input_array.Length;
var attr_values = new TF_DataType[len];

for (var j = 0; j < len; j++)
{
var eager_tensor = ops.convert_to_tensor(fast_input_array[j]);
attr_values[j] = eager_tensor.dtype;

c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle);

if (op_exec_info.run_callbacks)
{
flattened_inputs.Add(eager_tensor);
}
}

if (op_exec_info.run_callbacks)
{
flattened_attrs.Add(attr_name);
flattened_attrs.Add(attr_values);
}
c_api.TFE_OpSetAttrTypeList(op, attr_name, attr_values, attr_values.Length);
attr_list_sizes[attr_name] = len;
}
else
{
@@ -206,7 +232,7 @@ namespace Tensorflow.Eager
break;
default:
var tensor = tf.convert_to_tensor(inputs);
input_handle = (tensor as EagerTensor).EagerTensorHandle;
input_handle = tensor.EagerTensorHandle;
break;
}

@@ -237,7 +263,7 @@ namespace Tensorflow.Eager
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status.Handle);
if (!status.ok()) return;
if (is_list != 0)
SetOpAttrList(tf.context, op, key, value, type, null, status);
SetOpAttrList(tf.context, op, key, value as object[], type, null, status);
else
SetOpAttrScalar(tf.context, op, key, value, type, null, status);
status.Check(true);
@@ -282,20 +308,45 @@ namespace Tensorflow.Eager
else
{
if (is_list != 0)
#pragma warning disable CS0642 // Possible mistaken empty statement
;// SetOpAttrList
#pragma warning restore CS0642 // Possible mistaken empty statement
SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes, status);
else
SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes, status);
}
}

bool SetOpAttrList(Context ctx, SafeOpHandle op,
string key, object value, TF_AttrType type,
string key, object values, TF_AttrType type,
Dictionary<string, long> attr_list_sizes,
Status status)
{
return false;
if(type == TF_AttrType.TF_ATTR_SHAPE && values is TensorShape[] values1)
{
// Make one pass through the input counting the total number of
// dims across all the input lists.
var num_values = values1.Length;
attr_list_sizes[key] = num_values;
var dims = new IntPtr[num_values];
var num_dims = values1.Select(x => x.ndim).ToArray();

for (int i = 0; i < num_values; ++i)
{
dims[i] = Marshal.AllocHGlobal(sizeof(long) * values1[i].ndim);
tf.memcpy(dims[i], values1[i].dims.Select(x => (long)x).ToArray(), values1[i].ndim);
}

c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle);
Array.ForEach(dims, x => Marshal.FreeHGlobal(x));
}
else if(type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] values2)
{
c_api.TFE_OpSetAttrTypeList(op, key, values2, values2.Length);
}
else
{
throw new NotImplementedException("");
}

return true;
}

bool SetOpAttrScalar(Context ctx, SafeOpHandle op,


+ 45
- 0
src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs View File

@@ -0,0 +1,45 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
public partial class EagerTensor
{
public override string ToString()
{
switch (rank)
{
case -1:
return $"tf.Tensor: shape={TensorShape}, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}";
case 0:
return $"tf.Tensor: shape={TensorShape}, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}";
default:
return $"tf.Tensor: shape={TensorShape}, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}";
}
}

public static string GetFormattedString(TF_DataType dtype, NDArray nd)
{
if (nd.size == 0)
return "[]";

switch (dtype)
{
case TF_DataType.TF_STRING:
return string.Join(string.Empty, nd.ToArray<byte>()
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
case TF_DataType.TF_BOOL:
return (nd.GetByte(0) > 0).ToString();
case TF_DataType.TF_VARIANT:
case TF_DataType.TF_RESOURCE:
return "<unprintable>";
default:
return nd.ToString();
}
}
}
}

+ 0
- 33
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -9,7 +9,6 @@ namespace Tensorflow.Eager
{
public partial class EagerTensor : Tensor
{
public IntPtr EagerTensorHandle;
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status.Handle));

public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.status.Handle);
@@ -28,37 +27,5 @@ namespace Tensorflow.Eager
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.status.Handle);
return dims;
}

public override string ToString()
{
switch (rank)
{
case -1:
return $"tf.Tensor: shape=<unknown>, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}";
case 0:
return $"tf.Tensor: shape=(), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}";
default:
return $"tf.Tensor: shape=({string.Join(",", shape)}), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}";
}
}

public static string GetFormattedString(TF_DataType dtype, NDArray nd)
{
if (nd.size == 0)
return "[]";

switch (dtype)
{
case TF_DataType.TF_STRING:
return string.Join(string.Empty, nd.ToArray<byte>()
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
case TF_DataType.TF_BOOL:
return (nd.GetByte(0) > 0).ToString();
case TF_DataType.TF_RESOURCE:
return "<unprintable>";
default:
return nd.ToString();
}
}
}
}

+ 11
- 1
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -1,4 +1,5 @@
using System;
using Google.Protobuf;
using System;
using System.Runtime.InteropServices;
using Tensorflow.Device;
using Tensorflow.Eager;
@@ -156,6 +157,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status);

[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrShapeList(SafeOpHandle op, string attr_name, IntPtr[] dims, int[] num_dims, int num_values, SafeStatusHandle out_status);

[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value);

@@ -168,6 +172,12 @@ namespace Tensorflow
/// <param name="length">size_t</param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length);
[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values);

[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrValueProto(SafeOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status);

/// <summary>
///


+ 31
- 0
src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Framework.Models
{
/// <summary>
/// Describes a dense object with shape, dtype, and name.
/// </summary>
public class DenseSpec : TypeSpec
{
protected TensorShape _shape;
public TensorShape shape => _shape;

protected TF_DataType _dtype;
public TF_DataType dtype => _dtype;

protected string _name;
public string name => _name;

public DenseSpec(int[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
_shape = new TensorShape(shape);
_dtype = dtype;
_name = name;
}

public override string ToString()
=> $"shape={_shape}, dtype={_dtype.as_numpy_name()}, name={_name}";
}
}

+ 31
- 0
src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Framework.Models
{
public class TensorSpec : DenseSpec
{
public TensorSpec(int[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) :
base(shape, dtype, name)
{

}

public TensorSpec _unbatch()
{
if (_shape.ndim == 0)
throw new ValueError("Unbatching a tensor is only supported for rank >= 1");

return new TensorSpec(_shape.dims[1..], _dtype);
}

public TensorSpec _batch(int dim = -1)
{
var shapes = shape.dims.ToList();
shapes.Insert(0, dim);
return new TensorSpec(shapes.ToArray(), _dtype);
}
}
}

+ 13
- 0
src/TensorFlowNET.Core/Framework/Models/TypeSpec.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Framework.Models
{
/// <summary>
/// Specifies a TensorFlow value type.
/// </summary>
public class TypeSpec
{
}
}

+ 17
- 0
src/TensorFlowNET.Core/Framework/random_seed.py.cs View File

@@ -27,5 +27,22 @@ namespace Tensorflow
else
return (null, null);
}

public static (Tensor, Tensor) get_seed_tensor(int? op_seed = null)
{
var (seed, seed2) = get_seed(op_seed);
Tensor _seed, _seed2;
if (seed is null)
_seed = constant_op.constant(0, dtype: TF_DataType.TF_INT64, name: "seed");
else
_seed = constant_op.constant(seed.Value, dtype: TF_DataType.TF_INT64, name: "seed");

if (seed2 is null)
_seed2 = constant_op.constant(0, dtype: TF_DataType.TF_INT64, name: "seed2");
else
_seed2 = constant_op.constant(seed2.Value, dtype: TF_DataType.TF_INT64, name: "seed2");

return (_seed, _seed2);
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs View File

@@ -48,8 +48,8 @@ namespace Tensorflow.Keras.Optimizers
}
var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype());

return gen_training_ops.resource_apply_gradient_descent(var.Handle as EagerTensor,
_apply_state[device_dtype]["lr_t"] as EagerTensor,
return gen_training_ops.resource_apply_gradient_descent(var.Handle,
_apply_state[device_dtype]["lr_t"],
grad,
use_locking: _use_locking);
}


+ 178
- 0
src/TensorFlowNET.Core/Operations/dataset_ops.cs View File

@@ -0,0 +1,178 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class dataset_ops
{
/// <summary>
/// Creates a dataset that emits each dim-0 slice of `components` once.
/// </summary>
/// <param name="components"></param>
/// <param name="output_shapes"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor tensor_slice_dataset(Tensor[] components, TensorShape[] output_shapes, string name = null)
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"TensorSliceDataset", name,
null,
new object[]
{
components,
"output_shapes", output_shapes
});
return results[0];
}

throw new NotImplementedException("");
}

public Tensor repeat_dataset(Tensor input_dataset, Tensor count, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null)
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"RepeatDataset", name,
null,
input_dataset, count,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}

public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size,
Tensor seed, Tensor seed2, Tensor seed_generator,
TF_DataType[] output_types, TensorShape[] output_shapes,
bool reshuffle_each_iteration = true,
string name = null)
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"ShuffleDatasetV3", name,
null,
input_dataset, buffer_size,
seed, seed2, seed_generator,
"reshuffle_each_iteration", reshuffle_each_iteration,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}

public Tensor dummy_seed_generator(string name = null)
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"DummySeedGenerator", name,
null);
return results[0];
}

throw new NotImplementedException("");
}

/// <summary>
/// Creates a dataset that batches `batch_size` elements from `input_dataset`.
/// </summary>
/// <param name="input_dataset"></param>
/// <param name="buffer_size"></param>
/// <param name="drop_remainder"></param>
/// <param name="output_types"></param>
/// <param name="output_shapes"></param>
/// <param name="parallel_copy"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor batch_dataset_v2(Tensor input_dataset, Tensor buffer_size,
Tensor drop_remainder,
TF_DataType[] output_types, TensorShape[] output_shapes,
bool parallel_copy = false,
string name = null)
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"BatchDatasetV2", name,
null,
input_dataset, buffer_size, drop_remainder,
"parallel_copy", parallel_copy,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}

/// <summary>
/// Creates a dataset that asynchronously prefetches elements from `input_dataset`.
/// </summary>
/// <param name="input_dataset"></param>
/// <param name="buffer_size"></param>
/// <param name="output_types"></param>
/// <param name="output_shapes"></param>
/// <param name="slack_period"></param>
/// <param name="legacy_autotune"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor prefetch_dataset(Tensor input_dataset, Tensor buffer_size,
TF_DataType[] output_types, TensorShape[] output_shapes,
int? slack_period = 0,
bool legacy_autotune = true,
string name = null)
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"PrefetchDataset", name,
null,
input_dataset, buffer_size,
"output_types", output_types,
"output_shapes", output_shapes,
"slack_period", slack_period,
"legacy_autotune", legacy_autotune);
return results[0];
}

throw new NotImplementedException("");
}

/// <summary>
/// Creates a dataset that contains `count` elements from the `input_dataset`.
/// </summary>
/// <param name="input_dataset"></param>
/// <param name="count"></param>
/// <param name="output_types"></param>
/// <param name="output_shapes"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor take_dataset(Tensor input_dataset, Tensor count,
TF_DataType[] output_types, TensorShape[] output_shapes,
string name = null)
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"TakeDataset", name,
null,
input_dataset, count,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}
}
}

+ 274
- 0
src/TensorFlowNET.Core/Range.cs View File

@@ -0,0 +1,274 @@
// https://github.com/dotnet/corefx/blob/1597b894a2e9cac668ce6e484506eca778a85197/src/Common/src/CoreLib/System/Index.cs
// https://github.com/dotnet/corefx/blob/1597b894a2e9cac668ce6e484506eca778a85197/src/Common/src/CoreLib/System/Range.cs

using System.Runtime.CompilerServices;

namespace System
{
/// <summary>Represent a type can be used to index a collection either from the start or the end.</summary>
/// <remarks>
/// Index is used by the C# compiler to support the new index syntax
/// <code>
/// int[] someArray = new int[5] { 1, 2, 3, 4, 5 } ;
/// int lastElement = someArray[^1]; // lastElement = 5
/// </code>
/// </remarks>
internal readonly struct Index : IEquatable<Index>
{
private readonly int _value;

/// <summary>Construct an Index using a value and indicating if the index is from the start or from the end.</summary>
/// <param name="value">The index value. it has to be zero or positive number.</param>
/// <param name="fromEnd">Indicating if the index is from the start or from the end.</param>
/// <remarks>
/// If the Index constructed from the end, index value 1 means pointing at the last element and index value 0 means pointing at beyond last element.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Index(int value, bool fromEnd = false)
{
if (value < 0)
{
throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative");
}

if (fromEnd)
_value = ~value;
else
_value = value;
}

// The following private constructors mainly created for perf reason to avoid the checks
private Index(int value)
{
_value = value;
}

/// <summary>Create an Index pointing at first element.</summary>
public static Index Start => new Index(0);

/// <summary>Create an Index pointing at beyond last element.</summary>
public static Index End => new Index(~0);

/// <summary>Create an Index from the start at the position indicated by the value.</summary>
/// <param name="value">The index value from the start.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Index FromStart(int value)
{
if (value < 0)
{
throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative");
}

return new Index(value);
}

/// <summary>Create an Index from the end at the position indicated by the value.</summary>
/// <param name="value">The index value from the end.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Index FromEnd(int value)
{
if (value < 0)
{
throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative");
}

return new Index(~value);
}

/// <summary>Returns the index value.</summary>
public int Value
{
get
{
if (_value < 0)
{
return ~_value;
}
else
{
return _value;
}
}
}

/// <summary>Indicates whether the index is from the start or the end.</summary>
public bool IsFromEnd => _value < 0;

/// <summary>Calculate the offset from the start using the giving collection length.</summary>
/// <param name="length">The length of the collection that the Index will be used with. length has to be a positive value</param>
/// <remarks>
/// For performance reason, we don't validate the input length parameter and the returned offset value against negative values.
/// we don't validate either the returned offset is greater than the input length.
/// It is expected Index will be used with collections which always have non negative length/count. If the returned offset is negative and
/// then used to index a collection will get out of range exception which will be same affect as the validation.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int GetOffset(int length)
{
var offset = _value;
if (IsFromEnd)
{
// offset = length - (~value)
// offset = length + (~(~value) + 1)
// offset = length + value + 1

offset += length + 1;
}
return offset;
}

/// <summary>Indicates whether the current Index object is equal to another object of the same type.</summary>
/// <param name="value">An object to compare with this object</param>
public override bool Equals(object? value) => value is Index && _value == ((Index)value)._value;

/// <summary>Indicates whether the current Index object is equal to another Index object.</summary>
/// <param name="other">An object to compare with this object</param>
public bool Equals(Index other) => _value == other._value;

/// <summary>Returns the hash code for this instance.</summary>
public override int GetHashCode() => _value;

/// <summary>Converts integer number to an Index.</summary>
public static implicit operator Index(int value) => FromStart(value);

/// <summary>Converts the value of the current Index object to its equivalent string representation.</summary>
public override string ToString()
{
if (IsFromEnd)
return "^" + ((uint)Value).ToString();

return ((uint)Value).ToString();
}
}

/// <summary>Represent a range has start and end indexes.</summary>
/// <remarks>
/// Range is used by the C# compiler to support the range syntax.
/// <code>
/// int[] someArray = new int[5] { 1, 2, 3, 4, 5 };
/// int[] subArray1 = someArray[0..2]; // { 1, 2 }
/// int[] subArray2 = someArray[1..^0]; // { 2, 3, 4, 5 }
/// </code>
/// </remarks>
internal readonly struct Range : IEquatable<Range>
{
/// <summary>Represent the inclusive start index of the Range.</summary>
public Index Start { get; }

/// <summary>Represent the exclusive end index of the Range.</summary>
public Index End { get; }

/// <summary>Construct a Range object using the start and end indexes.</summary>
/// <param name="start">Represent the inclusive start index of the range.</param>
/// <param name="end">Represent the exclusive end index of the range.</param>
public Range(Index start, Index end)
{
Start = start;
End = end;
}

/// <summary>Indicates whether the current Range object is equal to another object of the same type.</summary>
/// <param name="value">An object to compare with this object</param>
public override bool Equals(object? value) =>
value is Range r &&
r.Start.Equals(Start) &&
r.End.Equals(End);

/// <summary>Indicates whether the current Range object is equal to another Range object.</summary>
/// <param name="other">An object to compare with this object</param>
public bool Equals(Range other) => other.Start.Equals(Start) && other.End.Equals(End);

/// <summary>Returns the hash code for this instance.</summary>
public override int GetHashCode()
{
return Start.GetHashCode() * 31 + End.GetHashCode();
}

/// <summary>Converts the value of the current Range object to its equivalent string representation.</summary>
public override string ToString()
{
return Start + ".." + End;
}

/// <summary>Create a Range object starting from start index to the end of the collection.</summary>
public static Range StartAt(Index start) => new Range(start, Index.End);

/// <summary>Create a Range object starting from first element in the collection to the end Index.</summary>
public static Range EndAt(Index end) => new Range(Index.Start, end);

/// <summary>Create a Range object starting from first element to the end.</summary>
public static Range All => new Range(Index.Start, Index.End);

/// <summary>Calculate the start offset and length of range object using a collection length.</summary>
/// <param name="length">The length of the collection that the range will be used with. length has to be a positive value.</param>
/// <remarks>
/// For performance reason, we don't validate the input length parameter against negative values.
/// It is expected Range will be used with collections which always have non negative length/count.
/// We validate the range is inside the length scope though.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public (int Offset, int Length) GetOffsetAndLength(int length)
{
int start;
var startIndex = Start;
if (startIndex.IsFromEnd)
start = length - startIndex.Value;
else
start = startIndex.Value;

int end;
var endIndex = End;
if (endIndex.IsFromEnd)
end = length - endIndex.Value;
else
end = endIndex.Value;

if ((uint)end > (uint)length || (uint)start > (uint)end)
{
throw new ArgumentOutOfRangeException(nameof(length));
}

return (start, end - start);
}
}
}

namespace System.Runtime.CompilerServices
{
internal static class RuntimeHelpers
{
/// <summary>
/// Slices the specified array using the specified range.
/// </summary>
public static T[] GetSubArray<T>(T[] array, Range range)
{
if (array == null)
{
throw new ArgumentNullException(nameof(array));
}

(int offset, int length) = range.GetOffsetAndLength(array.Length);

if (default(T) != null || typeof(T[]) == array.GetType())
{
// We know the type of the array to be exactly T[].

if (length == 0)
{
return Array.Empty<T>();
}

var dest = new T[length];
Array.Copy(array, offset, dest, 0, length);
return dest;
}
else
{
// The array is actually a U[] where U:T.
var dest = (T[])Array.CreateInstance(array.GetType().GetElementType(), length);
Array.Copy(array, offset, dest, 0, length);
return dest;
}
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.2.0</TargetTensorFlow>
<Version>0.20.0-preview2</Version>
<Version>0.20.0-preview3</Version>
<LangVersion>8.0</LangVersion>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company>


+ 4
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs View File

@@ -22,6 +22,7 @@ using System.Runtime.CompilerServices;
using System.Text;
using NumSharp.Utilities;
using static Tensorflow.Binding;
using Tensorflow.Framework.Models;

namespace Tensorflow
{
@@ -395,5 +396,8 @@ namespace Tensorflow
}
}
}

public TensorSpec ToTensorSpec()
=> new TensorSpec(shape, dtype, name);
}
}

+ 3
- 4
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -38,8 +38,7 @@ namespace Tensorflow
_TensorLike,
ITensorOrTensorArray,
IPackable<Tensor>,
ICanBeFlattened,
IPointerInputs
ICanBeFlattened
{
protected long _id;
private readonly Operation _op;
@@ -93,9 +92,9 @@ namespace Tensorflow
public object Tag { get; set; }

/// <summary>
/// Associated resource variable
/// TFE_TensorHandle
/// </summary>
public ResourceVariable ResourceVar { get; set; }
public IntPtr EagerTensorHandle { get; set; }

/// <summary>
/// Returns the shape of a tensor.


+ 9
- 1
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -254,7 +254,15 @@ namespace Tensorflow

public override string ToString()
{
return shape.ToString();
switch (rank)
{
case -1:
return $"<unknown>";
case 0:
return $"()";
default:
return $"{string.Join(",", shape).Replace("-1", "None")}";
}
}
}
}

+ 2
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -201,9 +201,11 @@ namespace Tensorflow
TF_DataType.TF_STRING => "string",
TF_DataType.TF_UINT8 => "uint8",
TF_DataType.TF_INT32 => "int32",
TF_DataType.TF_INT64 => "int64",
TF_DataType.TF_FLOAT => "float32",
TF_DataType.TF_BOOL => "bool",
TF_DataType.TF_RESOURCE => "resource",
TF_DataType.TF_VARIANT => "variant",
_ => type.ToString()
};



+ 0
- 11
src/TensorFlowNET.Core/Util/IPointerInputs.cs View File

@@ -1,11 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public interface IPointerInputs
{
public IntPtr ToPointer();
}
}

+ 2
- 0
src/TensorFlowNET.Core/tensorflow.memory.cs View File

@@ -49,6 +49,8 @@ namespace Tensorflow
{
if (src.Length == 0) return;

size = size * sizeof(T);

fixed (void* p = &src[0])
System.Buffer.MemoryCopy(p, dst.ToPointer(), size, size);
}


Loading…
Cancel
Save