| @@ -43,7 +43,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public partial class c_api | public partial class c_api | ||||
| { | { | ||||
| public const string TensorFlowLibName = "tensorflow"; | |||||
| public const string TensorFlowLibName = @"D:\SciSharp\tensorflow-google\bazel-bin\tensorflow\tensorflow.dll"; | |||||
| public static string StringPiece(IntPtr handle) | public static string StringPiece(IntPtr handle) | ||||
| { | { | ||||
| @@ -265,6 +265,17 @@ namespace Tensorflow | |||||
| yield return (i, values[i]); | yield return (i, values[i]); | ||||
| } | } | ||||
| public static IEnumerable<(int, T)> enumerate<T>(IEnumerable<T> values, int start = 0) | |||||
| { | |||||
| int i = 0; | |||||
| foreach(var val in values) | |||||
| { | |||||
| if (i < start) | |||||
| continue; | |||||
| yield return (i, val); | |||||
| } | |||||
| } | |||||
| [DebuggerStepThrough] | [DebuggerStepThrough] | ||||
| public static Dictionary<string, object> ConvertToDict(object dyn) | public static Dictionary<string, object> ConvertToDict(object dyn) | ||||
| { | { | ||||
| @@ -0,0 +1,41 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Framework.Models; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// A `Dataset` that batches contiguous elements from its input. | |||||
| /// </summary> | |||||
| public class BatchDataset : UnaryDataset | |||||
| { | |||||
| Tensor _batch_size; | |||||
| Tensor _drop_remainder; | |||||
| public BatchDataset(IDatasetV2 input_dataset, int batch_size, bool drop_remainder = false) : | |||||
| base(input_dataset) | |||||
| { | |||||
| _input_dataset = input_dataset; | |||||
| _batch_size = tf.convert_to_tensor(batch_size, dtype: TF_DataType.TF_INT64, name: "batch_size"); | |||||
| _drop_remainder = tf.convert_to_tensor(drop_remainder, dtype: TF_DataType.TF_BOOL, name: "drop_remainder"); | |||||
| if (drop_remainder) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| else | |||||
| { | |||||
| _structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray(); | |||||
| } | |||||
| variant_tensor = ops.batch_dataset_v2(input_dataset.variant_tensor, | |||||
| _batch_size, | |||||
| _drop_remainder, | |||||
| output_types, | |||||
| output_shapes); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,18 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Framework.Models; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class DatasetSource : DatasetV2 | |||||
| { | |||||
| protected Tensor[] _tensors; | |||||
| public DatasetSource() | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,54 @@ | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Framework.Models; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Abstract class representing a dataset with no inputs. | |||||
| /// </summary> | |||||
| public class DatasetV2 : IDatasetV2 | |||||
| { | |||||
| protected dataset_ops ops = new dataset_ops(); | |||||
| public Tensor variant_tensor { get; set; } | |||||
| public TensorSpec[] _structure { get; set; } | |||||
| public TensorShape[] output_shapes => _structure.Select(x => x.shape).ToArray(); | |||||
| public TF_DataType[] output_types => _structure.Select(x => x.dtype).ToArray(); | |||||
| public TensorSpec[] element_spec => _structure; | |||||
| public IDatasetV2 take(int count = -1) | |||||
| => new TakeDataset(this, count: count); | |||||
| public IDatasetV2 batch(int batch_size, bool drop_remainder = false) | |||||
| => new BatchDataset(this, batch_size, drop_remainder: drop_remainder); | |||||
| public IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null) | |||||
| => new PrefetchDataset(this, buffer_size: buffer_size, slack_period: slack_period); | |||||
| public IDatasetV2 repeat(int count = -1) | |||||
| => new RepeatDataset(this, count: count); | |||||
| public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) | |||||
| => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); | |||||
| public override string ToString() | |||||
| => $"{GetType().Name} shapes: ({_structure[0].shape}, {_structure[1].shape}), types: (tf.{_structure[0].dtype.as_numpy_name()}, tf.{_structure[1].dtype.as_numpy_name()})"; | |||||
| public IEnumerator<(Tensor, Tensor)> GetEnumerator() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| return this.GetEnumerator(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,11 +1,35 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Framework.Models; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public interface IDatasetV2 | |||||
| public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)> | |||||
| { | { | ||||
| Tensor variant_tensor { get; set; } | |||||
| TensorShape[] output_shapes { get; } | |||||
| TF_DataType[] output_types { get; } | |||||
| TensorSpec[] element_spec { get; } | |||||
| TensorSpec[] _structure { get; set; } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="count"></param> | |||||
| /// <returns></returns> | |||||
| IDatasetV2 repeat(int count = -1); | |||||
| IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true); | |||||
| IDatasetV2 batch(int batch_size, bool drop_remainder = false); | |||||
| IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); | |||||
| IDatasetV2 take(int count); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,29 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Creates a `Dataset` that prefetches elements from this dataset. | |||||
| /// </summary> | |||||
| public class PrefetchDataset : UnaryUnchangedStructureDataset | |||||
| { | |||||
| Tensor _buffer_size; | |||||
| public PrefetchDataset(IDatasetV2 input_dataset, | |||||
| long buffer_size = -1, | |||||
| int? slack_period = null) : | |||||
| base(input_dataset) | |||||
| { | |||||
| _buffer_size = tf.convert_to_tensor(buffer_size, dtype: TF_DataType.TF_INT64, name: "buffer_size"); | |||||
| variant_tensor = ops.prefetch_dataset(input_dataset.variant_tensor, | |||||
| _buffer_size, | |||||
| input_dataset.output_types, | |||||
| input_dataset.output_shapes, | |||||
| slack_period: slack_period); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,24 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// A `Dataset` that repeats its input several times. | |||||
| /// </summary> | |||||
| public class RepeatDataset : UnaryUnchangedStructureDataset | |||||
| { | |||||
| Tensor _count; | |||||
| public RepeatDataset(IDatasetV2 input_dataset, int count = -1) : | |||||
| base(input_dataset) | |||||
| { | |||||
| _count = constant_op.constant(count, dtype: TF_DataType.TF_INT64, name: "count"); | |||||
| variant_tensor = ops.repeat_dataset(input_dataset.variant_tensor, | |||||
| _count, | |||||
| input_dataset.output_types, | |||||
| input_dataset.output_shapes); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,37 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Randomly shuffles the elements of this dataset. | |||||
| /// </summary> | |||||
| public class ShuffleDataset : UnaryUnchangedStructureDataset | |||||
| { | |||||
| Tensor _buffer_size; | |||||
| Tensor _seed; | |||||
| Tensor _seed2; | |||||
| bool _reshuffle_each_iteration; | |||||
| public ShuffleDataset(IDatasetV2 input_dataset, | |||||
| long buffer_size, | |||||
| int? seed = null, | |||||
| bool reshuffle_each_iteration = true) : | |||||
| base(input_dataset) | |||||
| { | |||||
| _buffer_size = tf.convert_to_tensor(buffer_size, dtype: TF_DataType.TF_INT64, name: "buffer_size"); | |||||
| (_seed, _seed2) = random_seed.get_seed_tensor(seed); | |||||
| _reshuffle_each_iteration = reshuffle_each_iteration; | |||||
| var seed_generator = ops.dummy_seed_generator(); | |||||
| if (tf.context.executing_eagerly()) | |||||
| variant_tensor = ops.shuffle_dataset_v3(input_dataset.variant_tensor, _buffer_size, | |||||
| _seed, _seed2, seed_generator, | |||||
| output_types, output_shapes, | |||||
| reshuffle_each_iteration: _reshuffle_each_iteration); | |||||
| else | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,20 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class TakeDataset : UnaryUnchangedStructureDataset | |||||
| { | |||||
| Tensor _count; | |||||
| public TakeDataset(IDatasetV2 input_dataset, int count) : | |||||
| base(input_dataset) | |||||
| { | |||||
| _count = tf.convert_to_tensor(count, dtype: dtypes.int64, name: "count"); | |||||
| variant_tensor = ops.take_dataset(input_dataset.variant_tensor, _count, | |||||
| output_types, output_shapes); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,19 +1,23 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using NumSharp.Utilities; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Framework.Models; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class TensorSliceDataset : IDatasetV2 | |||||
| public class TensorSliceDataset : DatasetSource | |||||
| { | { | ||||
| NDArray features; | |||||
| NDArray labels; | |||||
| public TensorSliceDataset(NDArray features, NDArray labels) | public TensorSliceDataset(NDArray features, NDArray labels) | ||||
| { | { | ||||
| this.features = features; | |||||
| this.labels = labels; | |||||
| _tensors = new[] { tf.convert_to_tensor(features), tf.convert_to_tensor(labels) }; | |||||
| var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | |||||
| _structure = batched_spec.Select(x => x._unbatch()).ToArray(); | |||||
| variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,21 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Framework.Models; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Abstract class representing a dataset with one input. | |||||
| /// </summary> | |||||
| public class UnaryDataset : DatasetV2 | |||||
| { | |||||
| protected IDatasetV2 _input_dataset; | |||||
| public UnaryDataset(IDatasetV2 input_dataset) | |||||
| { | |||||
| _input_dataset = input_dataset; | |||||
| _structure = input_dataset._structure; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,18 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Represents a unary dataset with the same input and output structure. | |||||
| /// </summary> | |||||
| public class UnaryUnchangedStructureDataset : UnaryDataset | |||||
| { | |||||
| public UnaryUnchangedStructureDataset(IDatasetV2 input_dataset) : | |||||
| base(input_dataset) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -6,6 +6,7 @@ using System.Net; | |||||
| using System.Text; | using System.Text; | ||||
| using System.Threading; | using System.Threading; | ||||
| using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
| using Tensorflow.Framework.Models; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -7,6 +7,7 @@ using Google.Protobuf.WellKnownTypes; | |||||
| using System.Threading; | using System.Threading; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using System.Runtime.InteropServices.ComTypes; | using System.Runtime.InteropServices.ComTypes; | ||||
| using System.Runtime.InteropServices; | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| @@ -73,10 +74,11 @@ namespace Tensorflow.Eager | |||||
| // Add inferred attrs and inputs. | // Add inferred attrs and inputs. | ||||
| for (int i = 0; i < op_def.InputArg.Count; i++) | for (int i = 0; i < op_def.InputArg.Count; i++) | ||||
| { | { | ||||
| var input = args[kFastPathExecuteInputStartIndex + i]; | |||||
| var input_arg = op_def.InputArg[i]; | var input_arg = op_def.InputArg[i]; | ||||
| if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | ||||
| { | { | ||||
| int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length; | |||||
| int len = (input as object[]).Length; | |||||
| c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len); | c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len); | ||||
| if (op_exec_info.run_callbacks) | if (op_exec_info.run_callbacks) | ||||
| { | { | ||||
| @@ -102,7 +104,31 @@ namespace Tensorflow.Eager | |||||
| } | } | ||||
| else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | ||||
| { | { | ||||
| throw new NotImplementedException(""); | |||||
| var attr_name = input_arg.TypeListAttr; | |||||
| var fast_input_array = input as object[]; | |||||
| var len = fast_input_array.Length; | |||||
| var attr_values = new TF_DataType[len]; | |||||
| for (var j = 0; j < len; j++) | |||||
| { | |||||
| var eager_tensor = ops.convert_to_tensor(fast_input_array[j]); | |||||
| attr_values[j] = eager_tensor.dtype; | |||||
| c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle); | |||||
| if (op_exec_info.run_callbacks) | |||||
| { | |||||
| flattened_inputs.Add(eager_tensor); | |||||
| } | |||||
| } | |||||
| if (op_exec_info.run_callbacks) | |||||
| { | |||||
| flattened_attrs.Add(attr_name); | |||||
| flattened_attrs.Add(attr_values); | |||||
| } | |||||
| c_api.TFE_OpSetAttrTypeList(op, attr_name, attr_values, attr_values.Length); | |||||
| attr_list_sizes[attr_name] = len; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -206,7 +232,7 @@ namespace Tensorflow.Eager | |||||
| break; | break; | ||||
| default: | default: | ||||
| var tensor = tf.convert_to_tensor(inputs); | var tensor = tf.convert_to_tensor(inputs); | ||||
| input_handle = (tensor as EagerTensor).EagerTensorHandle; | |||||
| input_handle = tensor.EagerTensorHandle; | |||||
| break; | break; | ||||
| } | } | ||||
| @@ -237,7 +263,7 @@ namespace Tensorflow.Eager | |||||
| var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status.Handle); | var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status.Handle); | ||||
| if (!status.ok()) return; | if (!status.ok()) return; | ||||
| if (is_list != 0) | if (is_list != 0) | ||||
| SetOpAttrList(tf.context, op, key, value, type, null, status); | |||||
| SetOpAttrList(tf.context, op, key, value as object[], type, null, status); | |||||
| else | else | ||||
| SetOpAttrScalar(tf.context, op, key, value, type, null, status); | SetOpAttrScalar(tf.context, op, key, value, type, null, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| @@ -282,20 +308,45 @@ namespace Tensorflow.Eager | |||||
| else | else | ||||
| { | { | ||||
| if (is_list != 0) | if (is_list != 0) | ||||
| #pragma warning disable CS0642 // Possible mistaken empty statement | |||||
| ;// SetOpAttrList | |||||
| #pragma warning restore CS0642 // Possible mistaken empty statement | |||||
| SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes, status); | |||||
| else | else | ||||
| SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes, status); | SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes, status); | ||||
| } | } | ||||
| } | } | ||||
| bool SetOpAttrList(Context ctx, SafeOpHandle op, | bool SetOpAttrList(Context ctx, SafeOpHandle op, | ||||
| string key, object value, TF_AttrType type, | |||||
| string key, object values, TF_AttrType type, | |||||
| Dictionary<string, long> attr_list_sizes, | Dictionary<string, long> attr_list_sizes, | ||||
| Status status) | Status status) | ||||
| { | { | ||||
| return false; | |||||
| if(type == TF_AttrType.TF_ATTR_SHAPE && values is TensorShape[] values1) | |||||
| { | |||||
| // Make one pass through the input counting the total number of | |||||
| // dims across all the input lists. | |||||
| var num_values = values1.Length; | |||||
| attr_list_sizes[key] = num_values; | |||||
| var dims = new IntPtr[num_values]; | |||||
| var num_dims = values1.Select(x => x.ndim).ToArray(); | |||||
| for (int i = 0; i < num_values; ++i) | |||||
| { | |||||
| dims[i] = Marshal.AllocHGlobal(sizeof(long) * values1[i].ndim); | |||||
| tf.memcpy(dims[i], values1[i].dims.Select(x => (long)x).ToArray(), values1[i].ndim); | |||||
| } | |||||
| c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle); | |||||
| Array.ForEach(dims, x => Marshal.FreeHGlobal(x)); | |||||
| } | |||||
| else if(type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] values2) | |||||
| { | |||||
| c_api.TFE_OpSetAttrTypeList(op, key, values2, values2.Length); | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| return true; | |||||
| } | } | ||||
| bool SetOpAttrScalar(Context ctx, SafeOpHandle op, | bool SetOpAttrScalar(Context ctx, SafeOpHandle op, | ||||
| @@ -0,0 +1,45 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public partial class EagerTensor | |||||
| { | |||||
| public override string ToString() | |||||
| { | |||||
| switch (rank) | |||||
| { | |||||
| case -1: | |||||
| return $"tf.Tensor: shape={TensorShape}, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; | |||||
| case 0: | |||||
| return $"tf.Tensor: shape={TensorShape}, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; | |||||
| default: | |||||
| return $"tf.Tensor: shape={TensorShape}, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; | |||||
| } | |||||
| } | |||||
| public static string GetFormattedString(TF_DataType dtype, NDArray nd) | |||||
| { | |||||
| if (nd.size == 0) | |||||
| return "[]"; | |||||
| switch (dtype) | |||||
| { | |||||
| case TF_DataType.TF_STRING: | |||||
| return string.Join(string.Empty, nd.ToArray<byte>() | |||||
| .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); | |||||
| case TF_DataType.TF_BOOL: | |||||
| return (nd.GetByte(0) > 0).ToString(); | |||||
| case TF_DataType.TF_VARIANT: | |||||
| case TF_DataType.TF_RESOURCE: | |||||
| return "<unprintable>"; | |||||
| default: | |||||
| return nd.ToString(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -9,7 +9,6 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| public partial class EagerTensor : Tensor | public partial class EagerTensor : Tensor | ||||
| { | { | ||||
| public IntPtr EagerTensorHandle; | |||||
| public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status.Handle)); | public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status.Handle)); | ||||
| public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.status.Handle); | public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.status.Handle); | ||||
| @@ -28,37 +27,5 @@ namespace Tensorflow.Eager | |||||
| dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.status.Handle); | dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.status.Handle); | ||||
| return dims; | return dims; | ||||
| } | } | ||||
| public override string ToString() | |||||
| { | |||||
| switch (rank) | |||||
| { | |||||
| case -1: | |||||
| return $"tf.Tensor: shape=<unknown>, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; | |||||
| case 0: | |||||
| return $"tf.Tensor: shape=(), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; | |||||
| default: | |||||
| return $"tf.Tensor: shape=({string.Join(",", shape)}), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; | |||||
| } | |||||
| } | |||||
| public static string GetFormattedString(TF_DataType dtype, NDArray nd) | |||||
| { | |||||
| if (nd.size == 0) | |||||
| return "[]"; | |||||
| switch (dtype) | |||||
| { | |||||
| case TF_DataType.TF_STRING: | |||||
| return string.Join(string.Empty, nd.ToArray<byte>() | |||||
| .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); | |||||
| case TF_DataType.TF_BOOL: | |||||
| return (nd.GetByte(0) > 0).ToString(); | |||||
| case TF_DataType.TF_RESOURCE: | |||||
| return "<unprintable>"; | |||||
| default: | |||||
| return nd.ToString(); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using Google.Protobuf; | |||||
| using System; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow.Device; | using Tensorflow.Device; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| @@ -156,6 +157,9 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | public static extern void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_OpSetAttrShapeList(SafeOpHandle op, string attr_name, IntPtr[] dims, int[] num_dims, int num_values, SafeStatusHandle out_status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value); | public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value); | ||||
| @@ -168,6 +172,12 @@ namespace Tensorflow | |||||
| /// <param name="length">size_t</param> | /// <param name="length">size_t</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length); | public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_OpSetAttrValueProto(SafeOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -0,0 +1,31 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Framework.Models | |||||
| { | |||||
| /// <summary> | |||||
| /// Describes a dense object with shape, dtype, and name. | |||||
| /// </summary> | |||||
| public class DenseSpec : TypeSpec | |||||
| { | |||||
| protected TensorShape _shape; | |||||
| public TensorShape shape => _shape; | |||||
| protected TF_DataType _dtype; | |||||
| public TF_DataType dtype => _dtype; | |||||
| protected string _name; | |||||
| public string name => _name; | |||||
| public DenseSpec(int[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||||
| { | |||||
| _shape = new TensorShape(shape); | |||||
| _dtype = dtype; | |||||
| _name = name; | |||||
| } | |||||
| public override string ToString() | |||||
| => $"shape={_shape}, dtype={_dtype.as_numpy_name()}, name={_name}"; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,31 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Framework.Models | |||||
| { | |||||
| public class TensorSpec : DenseSpec | |||||
| { | |||||
| public TensorSpec(int[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) : | |||||
| base(shape, dtype, name) | |||||
| { | |||||
| } | |||||
| public TensorSpec _unbatch() | |||||
| { | |||||
| if (_shape.ndim == 0) | |||||
| throw new ValueError("Unbatching a tensor is only supported for rank >= 1"); | |||||
| return new TensorSpec(_shape.dims[1..], _dtype); | |||||
| } | |||||
| public TensorSpec _batch(int dim = -1) | |||||
| { | |||||
| var shapes = shape.dims.ToList(); | |||||
| shapes.Insert(0, dim); | |||||
| return new TensorSpec(shapes.ToArray(), _dtype); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,13 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Framework.Models | |||||
| { | |||||
| /// <summary> | |||||
| /// Specifies a TensorFlow value type. | |||||
| /// </summary> | |||||
| public class TypeSpec | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -27,5 +27,22 @@ namespace Tensorflow | |||||
| else | else | ||||
| return (null, null); | return (null, null); | ||||
| } | } | ||||
| public static (Tensor, Tensor) get_seed_tensor(int? op_seed = null) | |||||
| { | |||||
| var (seed, seed2) = get_seed(op_seed); | |||||
| Tensor _seed, _seed2; | |||||
| if (seed is null) | |||||
| _seed = constant_op.constant(0, dtype: TF_DataType.TF_INT64, name: "seed"); | |||||
| else | |||||
| _seed = constant_op.constant(seed.Value, dtype: TF_DataType.TF_INT64, name: "seed"); | |||||
| if (seed2 is null) | |||||
| _seed2 = constant_op.constant(0, dtype: TF_DataType.TF_INT64, name: "seed2"); | |||||
| else | |||||
| _seed2 = constant_op.constant(seed2.Value, dtype: TF_DataType.TF_INT64, name: "seed2"); | |||||
| return (_seed, _seed2); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -48,8 +48,8 @@ namespace Tensorflow.Keras.Optimizers | |||||
| } | } | ||||
| var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype()); | var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype()); | ||||
| return gen_training_ops.resource_apply_gradient_descent(var.Handle as EagerTensor, | |||||
| _apply_state[device_dtype]["lr_t"] as EagerTensor, | |||||
| return gen_training_ops.resource_apply_gradient_descent(var.Handle, | |||||
| _apply_state[device_dtype]["lr_t"], | |||||
| grad, | grad, | ||||
| use_locking: _use_locking); | use_locking: _use_locking); | ||||
| } | } | ||||
| @@ -0,0 +1,178 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class dataset_ops | |||||
| { | |||||
| /// <summary> | |||||
| /// Creates a dataset that emits each dim-0 slice of `components` once. | |||||
| /// </summary> | |||||
| /// <param name="components"></param> | |||||
| /// <param name="output_shapes"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor tensor_slice_dataset(Tensor[] components, TensorShape[] output_shapes, string name = null) | |||||
| { | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "TensorSliceDataset", name, | |||||
| null, | |||||
| new object[] | |||||
| { | |||||
| components, | |||||
| "output_shapes", output_shapes | |||||
| }); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| public Tensor repeat_dataset(Tensor input_dataset, Tensor count, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) | |||||
| { | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "RepeatDataset", name, | |||||
| null, | |||||
| input_dataset, count, | |||||
| "output_types", output_types, | |||||
| "output_shapes", output_shapes); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size, | |||||
| Tensor seed, Tensor seed2, Tensor seed_generator, | |||||
| TF_DataType[] output_types, TensorShape[] output_shapes, | |||||
| bool reshuffle_each_iteration = true, | |||||
| string name = null) | |||||
| { | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "ShuffleDatasetV3", name, | |||||
| null, | |||||
| input_dataset, buffer_size, | |||||
| seed, seed2, seed_generator, | |||||
| "reshuffle_each_iteration", reshuffle_each_iteration, | |||||
| "output_types", output_types, | |||||
| "output_shapes", output_shapes); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| public Tensor dummy_seed_generator(string name = null) | |||||
| { | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "DummySeedGenerator", name, | |||||
| null); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates a dataset that batches `batch_size` elements from `input_dataset`. | |||||
| /// </summary> | |||||
| /// <param name="input_dataset"></param> | |||||
| /// <param name="buffer_size"></param> | |||||
| /// <param name="drop_remainder"></param> | |||||
| /// <param name="output_types"></param> | |||||
| /// <param name="output_shapes"></param> | |||||
| /// <param name="parallel_copy"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor batch_dataset_v2(Tensor input_dataset, Tensor buffer_size, | |||||
| Tensor drop_remainder, | |||||
| TF_DataType[] output_types, TensorShape[] output_shapes, | |||||
| bool parallel_copy = false, | |||||
| string name = null) | |||||
| { | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "BatchDatasetV2", name, | |||||
| null, | |||||
| input_dataset, buffer_size, drop_remainder, | |||||
| "parallel_copy", parallel_copy, | |||||
| "output_types", output_types, | |||||
| "output_shapes", output_shapes); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates a dataset that asynchronously prefetches elements from `input_dataset`. | |||||
| /// </summary> | |||||
| /// <param name="input_dataset"></param> | |||||
| /// <param name="buffer_size"></param> | |||||
| /// <param name="output_types"></param> | |||||
| /// <param name="output_shapes"></param> | |||||
| /// <param name="slack_period"></param> | |||||
| /// <param name="legacy_autotune"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor prefetch_dataset(Tensor input_dataset, Tensor buffer_size, | |||||
| TF_DataType[] output_types, TensorShape[] output_shapes, | |||||
| int? slack_period = 0, | |||||
| bool legacy_autotune = true, | |||||
| string name = null) | |||||
| { | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "PrefetchDataset", name, | |||||
| null, | |||||
| input_dataset, buffer_size, | |||||
| "output_types", output_types, | |||||
| "output_shapes", output_shapes, | |||||
| "slack_period", slack_period, | |||||
| "legacy_autotune", legacy_autotune); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates a dataset that contains `count` elements from the `input_dataset`. | |||||
| /// </summary> | |||||
| /// <param name="input_dataset"></param> | |||||
| /// <param name="count"></param> | |||||
| /// <param name="output_types"></param> | |||||
| /// <param name="output_shapes"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor take_dataset(Tensor input_dataset, Tensor count, | |||||
| TF_DataType[] output_types, TensorShape[] output_shapes, | |||||
| string name = null) | |||||
| { | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "TakeDataset", name, | |||||
| null, | |||||
| input_dataset, count, | |||||
| "output_types", output_types, | |||||
| "output_shapes", output_shapes); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,274 @@ | |||||
| // https://github.com/dotnet/corefx/blob/1597b894a2e9cac668ce6e484506eca778a85197/src/Common/src/CoreLib/System/Index.cs | |||||
| // https://github.com/dotnet/corefx/blob/1597b894a2e9cac668ce6e484506eca778a85197/src/Common/src/CoreLib/System/Range.cs | |||||
| using System.Runtime.CompilerServices; | |||||
| namespace System | |||||
| { | |||||
| /// <summary>Represent a type can be used to index a collection either from the start or the end.</summary> | |||||
| /// <remarks> | |||||
| /// Index is used by the C# compiler to support the new index syntax | |||||
| /// <code> | |||||
| /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 } ; | |||||
| /// int lastElement = someArray[^1]; // lastElement = 5 | |||||
| /// </code> | |||||
| /// </remarks> | |||||
| internal readonly struct Index : IEquatable<Index> | |||||
| { | |||||
| private readonly int _value; | |||||
| /// <summary>Construct an Index using a value and indicating if the index is from the start or from the end.</summary> | |||||
| /// <param name="value">The index value. it has to be zero or positive number.</param> | |||||
| /// <param name="fromEnd">Indicating if the index is from the start or from the end.</param> | |||||
| /// <remarks> | |||||
| /// If the Index constructed from the end, index value 1 means pointing at the last element and index value 0 means pointing at beyond last element. | |||||
| /// </remarks> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public Index(int value, bool fromEnd = false) | |||||
| { | |||||
| if (value < 0) | |||||
| { | |||||
| throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); | |||||
| } | |||||
| if (fromEnd) | |||||
| _value = ~value; | |||||
| else | |||||
| _value = value; | |||||
| } | |||||
| // The following private constructors mainly created for perf reason to avoid the checks | |||||
| private Index(int value) | |||||
| { | |||||
| _value = value; | |||||
| } | |||||
| /// <summary>Create an Index pointing at first element.</summary> | |||||
| public static Index Start => new Index(0); | |||||
| /// <summary>Create an Index pointing at beyond last element.</summary> | |||||
| public static Index End => new Index(~0); | |||||
| /// <summary>Create an Index from the start at the position indicated by the value.</summary> | |||||
| /// <param name="value">The index value from the start.</param> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public static Index FromStart(int value) | |||||
| { | |||||
| if (value < 0) | |||||
| { | |||||
| throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); | |||||
| } | |||||
| return new Index(value); | |||||
| } | |||||
| /// <summary>Create an Index from the end at the position indicated by the value.</summary> | |||||
| /// <param name="value">The index value from the end.</param> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public static Index FromEnd(int value) | |||||
| { | |||||
| if (value < 0) | |||||
| { | |||||
| throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); | |||||
| } | |||||
| return new Index(~value); | |||||
| } | |||||
| /// <summary>Returns the index value.</summary> | |||||
| public int Value | |||||
| { | |||||
| get | |||||
| { | |||||
| if (_value < 0) | |||||
| { | |||||
| return ~_value; | |||||
| } | |||||
| else | |||||
| { | |||||
| return _value; | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary>Indicates whether the index is from the start or the end.</summary> | |||||
| public bool IsFromEnd => _value < 0; | |||||
| /// <summary>Calculate the offset from the start using the giving collection length.</summary> | |||||
| /// <param name="length">The length of the collection that the Index will be used with. length has to be a positive value</param> | |||||
| /// <remarks> | |||||
| /// For performance reason, we don't validate the input length parameter and the returned offset value against negative values. | |||||
| /// we don't validate either the returned offset is greater than the input length. | |||||
| /// It is expected Index will be used with collections which always have non negative length/count. If the returned offset is negative and | |||||
| /// then used to index a collection will get out of range exception which will be same affect as the validation. | |||||
| /// </remarks> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public int GetOffset(int length) | |||||
| { | |||||
| var offset = _value; | |||||
| if (IsFromEnd) | |||||
| { | |||||
| // offset = length - (~value) | |||||
| // offset = length + (~(~value) + 1) | |||||
| // offset = length + value + 1 | |||||
| offset += length + 1; | |||||
| } | |||||
| return offset; | |||||
| } | |||||
| /// <summary>Indicates whether the current Index object is equal to another object of the same type.</summary> | |||||
| /// <param name="value">An object to compare with this object</param> | |||||
| public override bool Equals(object? value) => value is Index && _value == ((Index)value)._value; | |||||
| /// <summary>Indicates whether the current Index object is equal to another Index object.</summary> | |||||
| /// <param name="other">An object to compare with this object</param> | |||||
| public bool Equals(Index other) => _value == other._value; | |||||
| /// <summary>Returns the hash code for this instance.</summary> | |||||
| public override int GetHashCode() => _value; | |||||
| /// <summary>Converts integer number to an Index.</summary> | |||||
| public static implicit operator Index(int value) => FromStart(value); | |||||
| /// <summary>Converts the value of the current Index object to its equivalent string representation.</summary> | |||||
| public override string ToString() | |||||
| { | |||||
| if (IsFromEnd) | |||||
| return "^" + ((uint)Value).ToString(); | |||||
| return ((uint)Value).ToString(); | |||||
| } | |||||
| } | |||||
| /// <summary>Represent a range has start and end indexes.</summary> | |||||
| /// <remarks> | |||||
| /// Range is used by the C# compiler to support the range syntax. | |||||
| /// <code> | |||||
| /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 }; | |||||
| /// int[] subArray1 = someArray[0..2]; // { 1, 2 } | |||||
| /// int[] subArray2 = someArray[1..^0]; // { 2, 3, 4, 5 } | |||||
| /// </code> | |||||
| /// </remarks> | |||||
| internal readonly struct Range : IEquatable<Range> | |||||
| { | |||||
| /// <summary>Represent the inclusive start index of the Range.</summary> | |||||
| public Index Start { get; } | |||||
| /// <summary>Represent the exclusive end index of the Range.</summary> | |||||
| public Index End { get; } | |||||
| /// <summary>Construct a Range object using the start and end indexes.</summary> | |||||
| /// <param name="start">Represent the inclusive start index of the range.</param> | |||||
| /// <param name="end">Represent the exclusive end index of the range.</param> | |||||
| public Range(Index start, Index end) | |||||
| { | |||||
| Start = start; | |||||
| End = end; | |||||
| } | |||||
| /// <summary>Indicates whether the current Range object is equal to another object of the same type.</summary> | |||||
| /// <param name="value">An object to compare with this object</param> | |||||
| public override bool Equals(object? value) => | |||||
| value is Range r && | |||||
| r.Start.Equals(Start) && | |||||
| r.End.Equals(End); | |||||
| /// <summary>Indicates whether the current Range object is equal to another Range object.</summary> | |||||
| /// <param name="other">An object to compare with this object</param> | |||||
| public bool Equals(Range other) => other.Start.Equals(Start) && other.End.Equals(End); | |||||
| /// <summary>Returns the hash code for this instance.</summary> | |||||
| public override int GetHashCode() | |||||
| { | |||||
| return Start.GetHashCode() * 31 + End.GetHashCode(); | |||||
| } | |||||
| /// <summary>Converts the value of the current Range object to its equivalent string representation.</summary> | |||||
| public override string ToString() | |||||
| { | |||||
| return Start + ".." + End; | |||||
| } | |||||
| /// <summary>Create a Range object starting from start index to the end of the collection.</summary> | |||||
| public static Range StartAt(Index start) => new Range(start, Index.End); | |||||
| /// <summary>Create a Range object starting from first element in the collection to the end Index.</summary> | |||||
| public static Range EndAt(Index end) => new Range(Index.Start, end); | |||||
| /// <summary>Create a Range object starting from first element to the end.</summary> | |||||
| public static Range All => new Range(Index.Start, Index.End); | |||||
| /// <summary>Calculate the start offset and length of range object using a collection length.</summary> | |||||
| /// <param name="length">The length of the collection that the range will be used with. length has to be a positive value.</param> | |||||
| /// <remarks> | |||||
| /// For performance reason, we don't validate the input length parameter against negative values. | |||||
| /// It is expected Range will be used with collections which always have non negative length/count. | |||||
| /// We validate the range is inside the length scope though. | |||||
| /// </remarks> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public (int Offset, int Length) GetOffsetAndLength(int length) | |||||
| { | |||||
| int start; | |||||
| var startIndex = Start; | |||||
| if (startIndex.IsFromEnd) | |||||
| start = length - startIndex.Value; | |||||
| else | |||||
| start = startIndex.Value; | |||||
| int end; | |||||
| var endIndex = End; | |||||
| if (endIndex.IsFromEnd) | |||||
| end = length - endIndex.Value; | |||||
| else | |||||
| end = endIndex.Value; | |||||
| if ((uint)end > (uint)length || (uint)start > (uint)end) | |||||
| { | |||||
| throw new ArgumentOutOfRangeException(nameof(length)); | |||||
| } | |||||
| return (start, end - start); | |||||
| } | |||||
| } | |||||
| } | |||||
| namespace System.Runtime.CompilerServices | |||||
| { | |||||
| internal static class RuntimeHelpers | |||||
| { | |||||
| /// <summary> | |||||
| /// Slices the specified array using the specified range. | |||||
| /// </summary> | |||||
| public static T[] GetSubArray<T>(T[] array, Range range) | |||||
| { | |||||
| if (array == null) | |||||
| { | |||||
| throw new ArgumentNullException(nameof(array)); | |||||
| } | |||||
| (int offset, int length) = range.GetOffsetAndLength(array.Length); | |||||
| if (default(T) != null || typeof(T[]) == array.GetType()) | |||||
| { | |||||
| // We know the type of the array to be exactly T[]. | |||||
| if (length == 0) | |||||
| { | |||||
| return Array.Empty<T>(); | |||||
| } | |||||
| var dest = new T[length]; | |||||
| Array.Copy(array, offset, dest, 0, length); | |||||
| return dest; | |||||
| } | |||||
| else | |||||
| { | |||||
| // The array is actually a U[] where U:T. | |||||
| var dest = (T[])Array.CreateInstance(array.GetType().GetElementType(), length); | |||||
| Array.Copy(array, offset, dest, 0, length); | |||||
| return dest; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -5,7 +5,7 @@ | |||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <TargetTensorFlow>2.2.0</TargetTensorFlow> | <TargetTensorFlow>2.2.0</TargetTensorFlow> | ||||
| <Version>0.20.0-preview2</Version> | |||||
| <Version>0.20.0-preview3</Version> | |||||
| <LangVersion>8.0</LangVersion> | <LangVersion>8.0</LangVersion> | ||||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| @@ -22,6 +22,7 @@ using System.Runtime.CompilerServices; | |||||
| using System.Text; | using System.Text; | ||||
| using NumSharp.Utilities; | using NumSharp.Utilities; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Framework.Models; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -395,5 +396,8 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| public TensorSpec ToTensorSpec() | |||||
| => new TensorSpec(shape, dtype, name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -38,8 +38,7 @@ namespace Tensorflow | |||||
| _TensorLike, | _TensorLike, | ||||
| ITensorOrTensorArray, | ITensorOrTensorArray, | ||||
| IPackable<Tensor>, | IPackable<Tensor>, | ||||
| ICanBeFlattened, | |||||
| IPointerInputs | |||||
| ICanBeFlattened | |||||
| { | { | ||||
| protected long _id; | protected long _id; | ||||
| private readonly Operation _op; | private readonly Operation _op; | ||||
| @@ -93,9 +92,9 @@ namespace Tensorflow | |||||
| public object Tag { get; set; } | public object Tag { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// Associated resource variable | |||||
| /// TFE_TensorHandle | |||||
| /// </summary> | /// </summary> | ||||
| public ResourceVariable ResourceVar { get; set; } | |||||
| public IntPtr EagerTensorHandle { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the shape of a tensor. | /// Returns the shape of a tensor. | ||||
| @@ -254,7 +254,15 @@ namespace Tensorflow | |||||
| public override string ToString() | public override string ToString() | ||||
| { | { | ||||
| return shape.ToString(); | |||||
| switch (rank) | |||||
| { | |||||
| case -1: | |||||
| return $"<unknown>"; | |||||
| case 0: | |||||
| return $"()"; | |||||
| default: | |||||
| return $"{string.Join(",", shape).Replace("-1", "None")}"; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -201,9 +201,11 @@ namespace Tensorflow | |||||
| TF_DataType.TF_STRING => "string", | TF_DataType.TF_STRING => "string", | ||||
| TF_DataType.TF_UINT8 => "uint8", | TF_DataType.TF_UINT8 => "uint8", | ||||
| TF_DataType.TF_INT32 => "int32", | TF_DataType.TF_INT32 => "int32", | ||||
| TF_DataType.TF_INT64 => "int64", | |||||
| TF_DataType.TF_FLOAT => "float32", | TF_DataType.TF_FLOAT => "float32", | ||||
| TF_DataType.TF_BOOL => "bool", | TF_DataType.TF_BOOL => "bool", | ||||
| TF_DataType.TF_RESOURCE => "resource", | TF_DataType.TF_RESOURCE => "resource", | ||||
| TF_DataType.TF_VARIANT => "variant", | |||||
| _ => type.ToString() | _ => type.ToString() | ||||
| }; | }; | ||||
| @@ -1,11 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public interface IPointerInputs | |||||
| { | |||||
| public IntPtr ToPointer(); | |||||
| } | |||||
| } | |||||
| @@ -49,6 +49,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (src.Length == 0) return; | if (src.Length == 0) return; | ||||
| size = size * sizeof(T); | |||||
| fixed (void* p = &src[0]) | fixed (void* p = &src[0]) | ||||
| System.Buffer.MemoryCopy(p, dst.ToPointer(), size, size); | System.Buffer.MemoryCopy(p, dst.ToPointer(), size, size); | ||||
| } | } | ||||