# Conflicts: # src/TensorFlowNET.Core/Tensorflow.Binding.csproj # src/TensorFlowNET.Keras/Datasets/Imdb.cstags/v0.110.4-Transformer-Model
| @@ -16,6 +16,7 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -50,6 +51,35 @@ namespace Tensorflow | |||
| return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | |||
| } | |||
| public unsafe static byte[] ByteStringPiece(Buffer? handle) | |||
| { | |||
| if (handle is null) | |||
| { | |||
| return new byte[0]; | |||
| } | |||
| var data = handle.ToArray(); | |||
| return data; | |||
| } | |||
| public unsafe static byte[] ByteStringPieceFromNativeString(IntPtr handle) | |||
| { | |||
| if (handle == IntPtr.Zero) | |||
| { | |||
| return new byte[0]; | |||
| } | |||
| byte* str_data = (byte*)handle.ToPointer(); | |||
| List<byte> bytes = new List<byte>(); | |||
| byte current = 255; | |||
| while (current != ((byte)'\0')) | |||
| { | |||
| current = *(str_data++); | |||
| bytes.Add(current); | |||
| } | |||
| var data = bytes.ToArray(); | |||
| return data; | |||
| } | |||
| [UnmanagedFunctionPointer(CallingConvention.Winapi)] | |||
| public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args); | |||
| @@ -10,7 +10,7 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
| public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | |||
| } | |||
| @@ -91,8 +91,7 @@ namespace Tensorflow | |||
| return identity(values.First(), name: scope); | |||
| }); | |||
| } | |||
| return gen_array_ops.concat_v2(values.ToArray(), ops.convert_to_tensor(axis), name: name); | |||
| return array_ops.concat(values.ToArray(), axis, name: name); | |||
| } | |||
| /// <summary> | |||
| @@ -163,14 +162,17 @@ namespace Tensorflow | |||
| /// Reverses specific dimensions of a tensor. | |||
| /// </summary> | |||
| /// <param name="tensor"></param> | |||
| /// <param name="axis"></param> | |||
| /// <param name="axis">The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)).</param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor reverse(Tensor tensor, int[] axis, string name = null) | |||
| => gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name); | |||
| public Tensor reverse(Tensor tensor, Tensor axis, string name = null) | |||
| => gen_array_ops.reverse(tensor, axis, name: name); | |||
| public Tensor reverse(Tensor tensor, Axis axis, string name = null) | |||
| { | |||
| if (axis.IsScalar) | |||
| { | |||
| axis = new Axis(axis.axis); | |||
| } | |||
| return array_ops.reverse(tensor, axis, name: name); | |||
| } | |||
| /// <summary> | |||
| /// Returns the rank of a tensor. | |||
| @@ -46,10 +46,10 @@ namespace Tensorflow | |||
| Tensor loop_vars, | |||
| int parallel_iterations = 10) | |||
| { | |||
| Func<Tensor[], Tensor> cond1 = x | |||
| Func<Tensors, Tensor> cond1 = x | |||
| => cond(x[0]); | |||
| Func<Tensor[], Tensor[]> body1 = x | |||
| Func<Tensors, Tensors> body1 = x | |||
| => new[] { body(x[0]) }; | |||
| var results = control_flow_ops.while_loop(cond1, | |||
| @@ -58,9 +58,9 @@ namespace Tensorflow | |||
| return results[0]; | |||
| } | |||
| public Tensor[] while_loop(Func<Tensor[], Tensor> cond, | |||
| Func<Tensor[], Tensor[]> body, | |||
| Tensor[] loop_vars, | |||
| public Tensor[] while_loop(Func<Tensors, Tensor> cond, | |||
| Func<Tensors, Tensors> body, | |||
| Tensors loop_vars, | |||
| int parallel_iterations = 10, | |||
| string name = null) | |||
| => control_flow_ops.while_loop(cond, body, loop_vars, | |||
| @@ -14,6 +14,10 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using OneOf.Types; | |||
| using System; | |||
| using System.Buffers.Text; | |||
| using Tensorflow.Contexts; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -162,17 +166,108 @@ namespace Tensorflow | |||
| public Tensor sobel_edges(Tensor image) | |||
| => image_ops_impl.sobel_edges(image); | |||
| public Tensor decode_jpeg(Tensor contents, | |||
| int channels = 0, | |||
| int ratio = 1, | |||
| bool fancy_upscaling = true, | |||
| bool try_recover_truncated = false, | |||
| int acceptable_fraction = 1, | |||
| string dct_method = "", | |||
| string name = null) | |||
| => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | |||
| fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | |||
| acceptable_fraction: acceptable_fraction, dct_method: dct_method); | |||
| /// <summary> | |||
| /// Adjust contrast of RGB or grayscale images. | |||
| /// </summary> | |||
| /// <param name="images">Images to adjust. At least 3-D.</param> | |||
| /// <param name="contrast_factor"></param> | |||
| /// <param name="name">A float multiplier for adjusting contrast.</param> | |||
| /// <returns>The contrast-adjusted image or images.</returns> | |||
| public Tensor adjust_contrast(Tensor images, float contrast_factor, string name = null) | |||
| => gen_image_ops.adjust_contrastv2(images, contrast_factor, name); | |||
| /// <summary> | |||
| /// Adjust hue of RGB images. | |||
| /// </summary> | |||
| /// <param name="images">RGB image or images. The size of the last dimension must be 3.</param> | |||
| /// <param name="delta">float. How much to add to the hue channel.</param> | |||
| /// <param name="name">A name for this operation (optional).</param> | |||
| /// <returns>Adjusted image(s), same shape and DType as `image`.</returns> | |||
| /// <exception cref="ValueError">if `delta` is not in the interval of `[-1, 1]`.</exception> | |||
| public Tensor adjust_hue(Tensor images, float delta, string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| if (delta < -1f || delta > 1f) | |||
| throw new ValueError("delta must be in the interval [-1, 1]"); | |||
| } | |||
| return gen_image_ops.adjust_hue(images, delta, name: name); | |||
| } | |||
| /// <summary> | |||
| /// Adjust saturation of RGB images. | |||
| /// </summary> | |||
| /// <param name="image">RGB image or images. The size of the last dimension must be 3.</param> | |||
| /// <param name="saturation_factor">float. Factor to multiply the saturation by.</param> | |||
| /// <param name="name">A name for this operation (optional).</param> | |||
| /// <returns>Adjusted image(s), same shape and DType as `image`.</returns> | |||
| public Tensor adjust_saturation(Tensor image, float saturation_factor, string name = null) | |||
| => gen_image_ops.adjust_saturation(image, saturation_factor, name); | |||
| /// <summary> | |||
| /// Greedily selects a subset of bounding boxes in descending order of score. | |||
| /// </summary> | |||
| /// <param name="boxes"> | |||
| /// A 4-D float `Tensor` of shape `[batch_size, num_boxes, q, 4]`. If `q` | |||
| /// is 1 then same boxes are used for all classes otherwise, if `q` is equal | |||
| /// to number of classes, class-specific boxes are used. | |||
| /// </param> | |||
| /// <param name="scores"> | |||
| /// A 3-D float `Tensor` of shape `[batch_size, num_boxes, num_classes]` | |||
| /// representing a single score corresponding to each box(each row of boxes). | |||
| /// </param> | |||
| /// <param name="max_output_size_per_class"> | |||
| /// A scalar integer `Tensor` representing the | |||
| /// maximum number of boxes to be selected by non-max suppression per class | |||
| /// </param> | |||
| /// <param name="max_total_size"> | |||
| /// A int32 scalar representing maximum number of boxes retained | |||
| /// over all classes.Note that setting this value to a large number may | |||
| /// result in OOM error depending on the system workload. | |||
| /// </param> | |||
| /// <param name="iou_threshold"> | |||
| /// A float representing the threshold for deciding whether boxes | |||
| /// overlap too much with respect to IOU. | |||
| /// </param> | |||
| /// <param name="score_threshold"> | |||
| /// A float representing the threshold for deciding when to | |||
| /// remove boxes based on score. | |||
| /// </param> | |||
| /// <param name="pad_per_class"> | |||
| /// If false, the output nmsed boxes, scores and classes are | |||
| /// padded/clipped to `max_total_size`. If true, the output nmsed boxes, scores and classes are padded to be of length `max_size_per_class`*`num_classes`, | |||
| /// unless it exceeds `max_total_size` in which case it is clipped to `max_total_size`. Defaults to false. | |||
| /// </param> | |||
| /// <param name="clip_boxes"> | |||
| /// If true, the coordinates of output nmsed boxes will be clipped | |||
| /// to[0, 1]. If false, output the box coordinates as it is. Defaults to true. | |||
| /// </param> | |||
| /// <returns> | |||
| /// 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor containing the non-max suppressed boxes. | |||
| /// 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing the scores for the boxes. | |||
| /// 'nmsed_classes': A [batch_size, max_detections] float32 tensor containing the class for boxes. | |||
| /// 'valid_detections': A [batch_size] int32 tensor indicating the number of | |||
| /// valid detections per batch item. Only the top valid_detections[i] entries | |||
| /// in nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the | |||
| /// entries are zero paddings. | |||
| /// </returns> | |||
| public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression( | |||
| Tensor boxes, | |||
| Tensor scores, | |||
| int max_output_size_per_class, | |||
| int max_total_size, | |||
| float iou_threshold, | |||
| float score_threshold, | |||
| bool pad_per_class = false, | |||
| bool clip_boxes = true) | |||
| { | |||
| var iou_threshold_t = ops.convert_to_tensor(iou_threshold, TF_DataType.TF_FLOAT, name: "iou_threshold"); | |||
| var score_threshold_t = ops.convert_to_tensor(score_threshold, TF_DataType.TF_FLOAT, name: "score_threshold"); | |||
| var max_total_size_t = ops.convert_to_tensor(max_total_size); | |||
| var max_output_size_per_class_t = ops.convert_to_tensor(max_output_size_per_class); | |||
| return gen_image_ops.combined_non_max_suppression(boxes, scores, max_output_size_per_class_t, max_total_size_t, | |||
| iou_threshold_t, score_threshold_t, pad_per_class, clip_boxes); | |||
| } | |||
| /// <summary> | |||
| /// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change. | |||
| @@ -187,7 +282,19 @@ namespace Tensorflow | |||
| /// <param name="name">A name for the operation (optional).</param> | |||
| /// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns> | |||
| public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) => | |||
| image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); | |||
| gen_image_ops.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); | |||
| public Tensor decode_jpeg(Tensor contents, | |||
| int channels = 0, | |||
| int ratio = 1, | |||
| bool fancy_upscaling = true, | |||
| bool try_recover_truncated = false, | |||
| int acceptable_fraction = 1, | |||
| string dct_method = "", | |||
| string name = null) | |||
| => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | |||
| fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | |||
| acceptable_fraction: acceptable_fraction, dct_method: dct_method); | |||
| public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true, | |||
| bool uniform_noise = true, string name = null) | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| @@ -42,10 +43,20 @@ namespace Tensorflow | |||
| public Tensor multiply(Tensor x, Tensor y, string name = null) | |||
| => math_ops.multiply(x, y, name: name); | |||
| public Tensor divide_no_nan(Tensor a, Tensor b, string name = null) | |||
| => math_ops.div_no_nan(a, b); | |||
| /// <summary> | |||
| /// Computes the Euclidean norm of elements across dimensions of a tensor. | |||
| /// </summary> | |||
| /// <param name="input_tensor">The tensor to reduce. Should have numeric type.</param> | |||
| /// <param name="axis">The dimensions to reduce. If `None` (the default), reduces all dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`</param> | |||
| /// <param name="keepdims">If true, retains reduced dimensions with length 1.</param> | |||
| /// <param name="name">A name for the operation (optional).</param> | |||
| /// <returns>The reduced tensor, of the same dtype as the input_tensor.</returns> | |||
| public Tensor reduce_euclidean_norm(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) | |||
| => math_ops.reduce_euclidean_norm(input_tensor, axis: axis, keepdims: keepdims, name); | |||
| public Tensor square(Tensor x, string name = null) | |||
| => math_ops.square(x, name: name); | |||
| @@ -354,7 +365,7 @@ namespace Tensorflow | |||
| => a / b; | |||
| public Tensor sqrt(Tensor a, string name = null) | |||
| => gen_math_ops.sqrt(a, name); | |||
| => math_ops.sqrt(a, name); | |||
| public Tensor sign(Tensor a, string name = null) | |||
| => gen_math_ops.sign(a, name); | |||
| @@ -452,7 +463,18 @@ namespace Tensorflow | |||
| /// <returns></returns> | |||
| public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | |||
| => gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
| /// <summary> | |||
| /// return scalar product | |||
| /// </summary> | |||
| /// <typeparam name="Tx"></typeparam> | |||
| /// <typeparam name="Ty"></typeparam> | |||
| /// <param name="x"></param> | |||
| /// <param name="y"></param> | |||
| /// <param name="axes"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor dot_prod<Tx, Ty>(Tx x, Ty y, NDArray axes, string name = null) | |||
| => math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name); | |||
| public Tensor negative(Tensor x, string name = null) | |||
| => gen_math_ops.neg(x, name); | |||
| @@ -600,5 +622,7 @@ namespace Tensorflow | |||
| => gen_math_ops.squared_difference(x: x, y: y, name: name); | |||
| public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null, | |||
| string name = null) => gen_ops.complex(real, imag, dtype, name); | |||
| public Tensor exp(Tensor x, | |||
| string name = null) => gen_math_ops.exp(x, name); | |||
| } | |||
| } | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Xml.Linq; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Operations.Activation; | |||
| using static Tensorflow.Binding; | |||
| @@ -126,6 +127,26 @@ namespace Tensorflow | |||
| name: name, | |||
| exponential_avg_factor: exponential_avg_factor); | |||
| /// <summary> | |||
| /// Normalizes a tensor by `mean` and `variance`, and applies (optionally) a`scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\). | |||
| /// </summary> | |||
| /// <param name="x">A floating point tensor.</param> | |||
| /// <param name="mean">A mean `Tensor`.</param> | |||
| /// <param name="variance">A variance `Tensor`.</param> | |||
| /// <param name="offset"> An offset `Tensor`, often denoted \\(\beta\\) in equations, or NULL. If present, will be added to the normalized tensor.</param> | |||
| /// <param name="scale"> A scale `Tensor`, often denoted \\(\gamma\\) in equations, or NULL. If present, the scale is applied to the normalized tensor.</param> | |||
| /// <param name="variance_epsilon"> A small float number to avoid dividing by 0.</param> | |||
| /// <param name="name">A name for this operation.</param> | |||
| /// <returns>the normalized, scaled, offset tensor.</returns> | |||
| public Tensor batch_normalization(Tensor x, | |||
| Tensor mean, | |||
| Tensor variance, | |||
| Tensor offset, | |||
| Tensor scale, | |||
| float variance_epsilon, | |||
| string name = null) => nn_impl.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name); | |||
| public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null) | |||
| => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name); | |||
| @@ -31,6 +31,6 @@ namespace Tensorflow | |||
| public Tensor reshape(Tensor tensor, | |||
| object[] shape, | |||
| string name = null) | |||
| => gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name); | |||
| => array_ops.reshape(tensor, shape, name); | |||
| } | |||
| } | |||
| @@ -68,20 +68,27 @@ namespace Tensorflow | |||
| /// <param name="name">A name for the operation (optional)</param> | |||
| /// <returns>if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; | |||
| /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.</returns> | |||
| public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) | |||
| public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) | |||
| => array_ops.split( | |||
| value: value, | |||
| num_split: num_split, | |||
| num_or_size_splits: num_split, | |||
| axis: axis, | |||
| name: name); | |||
| public Tensor[] split(Tensor value, int num_split, int axis, string name = null) | |||
| public Tensor[] split(Tensor value, int[] num_split, Axis axis, string name = null) | |||
| => array_ops.split( | |||
| value: value, | |||
| num_split: num_split, | |||
| num_or_size_splits: num_split, | |||
| axis: axis, | |||
| name: name); | |||
| //public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) | |||
| // => array_ops.split( | |||
| // value: value, | |||
| // num_or_size_splits: num_split, | |||
| // axis: axis, | |||
| // name: name); | |||
| public Tensor ensure_shape(Tensor x, Shape shape, string name = null) | |||
| { | |||
| return gen_ops.ensure_shape(x, shape, name); | |||
| @@ -23,7 +23,7 @@ namespace Tensorflow | |||
| => gen_array_ops.tile(input, multiples, name); | |||
| public Tensor tile(Tensor input, object[] multiples, string name = null) | |||
| => gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name); | |||
| => array_ops.tile(input, constant_op.constant(shape_utils.from_object_array(multiples).dims), name); | |||
| public Tensor tile(Tensor input, Shape multiples, string name = null) | |||
| { | |||
| @@ -486,7 +486,28 @@ namespace Tensorflow | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| public static NDArray GetFlattenArray(NDArray x) | |||
| { | |||
| switch (x.GetDataType()) | |||
| { | |||
| case TF_DataType.TF_FLOAT: | |||
| x = x.ToArray<float>(); | |||
| break; | |||
| case TF_DataType.TF_DOUBLE: | |||
| x = x.ToArray<double>(); | |||
| break; | |||
| case TF_DataType.TF_INT16: | |||
| case TF_DataType.TF_INT32: | |||
| x = x.ToArray<int>(); | |||
| break; | |||
| case TF_DataType.TF_INT64: | |||
| x = x.ToArray<long>(); | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| return x; | |||
| } | |||
| public static TF_DataType GetDataType(this object data) | |||
| { | |||
| var type = data.GetType(); | |||
| @@ -503,7 +524,7 @@ namespace Tensorflow | |||
| case Tensors tensors: | |||
| return tensors.dtype; | |||
| case IEnumerable<Tensor> tensors: | |||
| return tensors.First().dtype; | |||
| return tensors.Where(x => x is not null).First().dtype; | |||
| case RefVariable variable: | |||
| return variable.dtype; | |||
| case ResourceVariable variable: | |||
| @@ -3,16 +3,16 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Extensions | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class JObjectExtensions | |||
| { | |||
| public static T? TryGetOrReturnNull<T>(this JObject obj, string key) | |||
| { | |||
| var res = obj[key]; | |||
| if(res is null) | |||
| if (res is null) | |||
| { | |||
| return default(T); | |||
| return default; | |||
| } | |||
| else | |||
| { | |||
| @@ -0,0 +1,38 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class LinqExtensions | |||
| { | |||
| #if NETSTANDARD2_0 | |||
| public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count) | |||
| { | |||
| return sequence.Skip(sequence.Count() - count); | |||
| } | |||
| public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count) | |||
| { | |||
| return sequence.Take(sequence.Count() - count); | |||
| } | |||
| #endif | |||
| public static Tensors ToTensors(this Tensor[] tensors) | |||
| { | |||
| return new Tensors(tensors); | |||
| } | |||
| public static Tensors ToTensors(this IList<Tensor> tensors) | |||
| { | |||
| return new Tensors(tensors); | |||
| } | |||
| public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) | |||
| { | |||
| first = values.Item1; | |||
| second = values.Item2; | |||
| third = values.Item3; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,33 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class NestExtensions | |||
| { | |||
| public static Tensors ToTensors(this INestable<Tensor> tensors) | |||
| { | |||
| return new Tensors(tensors.AsNest()); | |||
| } | |||
| public static Tensors? ToTensors(this Nest<Tensor> tensors) | |||
| { | |||
| return Tensors.FromNest(tensors); | |||
| } | |||
| /// <summary> | |||
| /// If the nested object is already a nested type, this function could reduce it. | |||
| /// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. | |||
| /// </summary> | |||
| /// <typeparam name="TIn"></typeparam> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="input"></param> | |||
| /// <returns></returns> | |||
| public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut> | |||
| { | |||
| return Nest<TOut>.ReduceFrom(input); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,20 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// This is a temp solution, which should be removed after refactoring `Tensors` | |||
| /// </summary> | |||
| [Obsolete] | |||
| public class FakeTensorByTensorArray: Tensor | |||
| { | |||
| public TensorArray TensorArray { get; set; } | |||
| public FakeTensorByTensorArray(TensorArray array) | |||
| { | |||
| TensorArray = array; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,69 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public class GeneralizedTensorShape: Nest<Shape> | |||
| { | |||
| public GeneralizedTensorShape(Shape value, string? name = null) | |||
| { | |||
| NodeValue = value; | |||
| NestType = NestType.Node; | |||
| } | |||
| public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null) | |||
| { | |||
| ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList(); | |||
| Name = name; | |||
| NestType = NestType.List; | |||
| } | |||
| public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null) | |||
| { | |||
| DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>); | |||
| Name = name; | |||
| NestType = NestType.Dictionary; | |||
| } | |||
| public GeneralizedTensorShape(Nest<Shape> other) | |||
| { | |||
| NestType = other.NestType; | |||
| NodeValue = other.NodeValue; | |||
| DictValue = other.DictValue; | |||
| ListValue = other.ListValue; | |||
| Name = other.Name; | |||
| } | |||
| public Shape ToSingleShape() | |||
| { | |||
| var shapes = Flatten().ToList(); | |||
| if (shapes.Count != 1) | |||
| { | |||
| throw new ValueError("The generalized shape contains more than 1 dim."); | |||
| } | |||
| return shapes[0]; | |||
| } | |||
| public long ToNumber() | |||
| { | |||
| var shapes = Flatten().ToList(); | |||
| if (shapes.Count != 1 || shapes[0].ndim != 1) | |||
| { | |||
| throw new ValueError("The generalized shape contains more than 1 dim."); | |||
| } | |||
| return shapes[0].dims[0]; | |||
| } | |||
| public INestStructure<TensorShapeConfig> ToTensorShapeConfigs() | |||
| { | |||
| return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() }); | |||
| } | |||
| public static implicit operator GeneralizedTensorShape(Shape shape) | |||
| { | |||
| return new GeneralizedTensorShape(shape); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,40 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// This interface indicates that a class may have a nested structure and provide | |||
| /// methods to manipulate with the structure. | |||
| /// </summary> | |||
| public interface INestStructure<T>: INestable<T> | |||
| { | |||
| NestType NestType { get; } | |||
| /// <summary> | |||
| /// The item count of depth 1 of the nested structure. | |||
| /// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3. | |||
| /// </summary> | |||
| int ShallowNestedCount { get; } | |||
| /// <summary> | |||
| /// The total item count of depth 1 of the nested structure. | |||
| /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. | |||
| /// </summary> | |||
| int TotalNestedCount { get; } | |||
| /// <summary> | |||
| /// Flatten the Nestable object. Node that if the object contains only one value, | |||
| /// it will be flattened to an enumerable with one element. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| IEnumerable<T> Flatten(); | |||
| /// <summary> | |||
| /// Construct a new object with the same nested structure. | |||
| /// </summary> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="func"></param> | |||
| /// <returns></returns> | |||
| INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func); | |||
| } | |||
| } | |||
| @@ -0,0 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public interface INestable<T> | |||
| { | |||
| Nest<T> AsNest(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,21 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// This interface is used when some corresponding python methods have optional args. | |||
| /// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while | |||
| /// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs` | |||
| /// as the parameter of the method. | |||
| /// </summary> | |||
| public interface IOptionalArgs | |||
| { | |||
| /// <summary> | |||
| /// The identifier of the class. It is not an argument but only something to | |||
| /// separate different OptionalArgs. | |||
| /// </summary> | |||
| string Identifier { get; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,62 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public static class Nest | |||
| { | |||
| /// <summary> | |||
| /// Pack the flat items to a nested sequence by the template. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="template"></param> | |||
| /// <param name="flatItems"></param> | |||
| /// <returns></returns> | |||
| public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems) | |||
| { | |||
| return template.AsNest().PackSequence(flatItems); | |||
| } | |||
| /// <summary> | |||
| /// Pack the flat items to a nested sequence by the template. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="template"></param> | |||
| /// <param name="flatItems"></param> | |||
| /// <returns></returns> | |||
| public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems) | |||
| { | |||
| return template.AsNest().PackSequence(flatItems.ToArray()); | |||
| } | |||
| /// <summary> | |||
| /// Flatten the nested object. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="nestedObject"></param> | |||
| /// <returns></returns> | |||
| public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject) | |||
| { | |||
| return nestedObject.AsNest().Flatten(); | |||
| } | |||
| /// <summary> | |||
| /// Map the structure with specified function. | |||
| /// </summary> | |||
| /// <typeparam name="TIn"></typeparam> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="func"></param> | |||
| /// <param name="nestedObject"></param> | |||
| /// <returns></returns> | |||
| public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject) | |||
| { | |||
| return nestedObject.AsNest().MapStructure(func); | |||
| } | |||
| public static bool IsNested<T>(INestable<T> obj) | |||
| { | |||
| return obj.AsNest().IsNested(); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,485 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Extensions; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public enum NestType | |||
| { | |||
| Empty, | |||
| Node, | |||
| List, | |||
| Dictionary | |||
| } | |||
| /// <summary> | |||
| /// A nested structure which may inclulde value, list and dictionary. | |||
| /// Note that dictionary does not ensure the data order. When using it as IEnumerable, | |||
| /// its order is depth-first. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| public class Nest<T> : INestStructure<T>, IEnumerable<T> | |||
| { | |||
| private static readonly Nest<T> _empty = new Nest<T>() | |||
| { | |||
| NestType = NestType.Empty, | |||
| }; | |||
| public static Nest<T> Empty => _empty; | |||
| public NestType NestType { get; protected set; } | |||
| public string? Name { get; set; } | |||
| public T? NodeValue { get; protected set; } | |||
| public List<INestStructure<T>>? ListValue { get; protected set; } | |||
| public Dictionary<string, INestStructure<T>>? DictValue { get; protected set; } | |||
| public int ShallowNestedCount | |||
| { | |||
| get | |||
| { | |||
| if (NestType == NestType.Empty) | |||
| { | |||
| return 0; | |||
| } | |||
| else if (NestType == NestType.Node) | |||
| { | |||
| return 1; | |||
| } | |||
| else if (NestType == NestType.List) | |||
| { | |||
| return ListValue!.Count; | |||
| } | |||
| else // dict | |||
| { | |||
| return DictValue!.Count; | |||
| } | |||
| } | |||
| } | |||
| public int TotalNestedCount | |||
| { | |||
| get | |||
| { | |||
| return Flatten().Count(); | |||
| } | |||
| } | |||
| protected Nest() { } | |||
| public Nest(T value, string? name = null) | |||
| { | |||
| NodeValue = value; | |||
| Name = name; | |||
| NestType = NestType.Node; | |||
| } | |||
| public Nest(IEnumerable<INestStructure<T>> values, string? name = null) | |||
| { | |||
| ListValue = values.ToList(); | |||
| Name = name; | |||
| NestType = NestType.List; | |||
| } | |||
| public Nest(Dictionary<string, INestStructure<T>> value, string? name = null) | |||
| { | |||
| DictValue = value; | |||
| Name = name; | |||
| NestType = NestType.Dictionary; | |||
| } | |||
| public Nest(Nest<T> other) | |||
| { | |||
| NestType = other.NestType; | |||
| NodeValue = other.NodeValue; | |||
| DictValue = other.DictValue; | |||
| ListValue = other.ListValue; | |||
| Name = other.Name; | |||
| } | |||
| public virtual IEnumerable<T> Flatten() | |||
| { | |||
| return FlattenInternal(this); | |||
| } | |||
| public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
| { | |||
| return MapStructureInternal(func); | |||
| } | |||
| /// <summary> | |||
| /// Pack the flat items to a nested sequence by the template. | |||
| /// </summary> | |||
| /// <param name="flatItems"></param> | |||
| /// <returns></returns> | |||
| public virtual Nest<TOut> PackSequence<TOut>(TOut[] flatItems) | |||
| { | |||
| if(flatItems.Length == 0) | |||
| { | |||
| return Nest<TOut>.Empty; | |||
| } | |||
| int index = 0; | |||
| return PackSequenceInternal(this, flatItems, ref index); | |||
| } | |||
| private static Nest<TOut> PackSequenceInternal<TOut>(Nest<T> template, TOut[] flatItems, ref int index) | |||
| { | |||
| if(template.NestType == NestType.Node) | |||
| { | |||
| if(index >= flatItems.Length) | |||
| { | |||
| throw new InvalidArgumentError("The template and flat items are not matched."); | |||
| } | |||
| return new Nest<TOut>(flatItems[index++]); | |||
| } | |||
| else if(template.NestType == NestType.List) | |||
| { | |||
| List<Nest<TOut>> nestedObjects = new List<Nest<TOut>>(); | |||
| for (int i = 0; i < template.ListValue!.Count; i++) | |||
| { | |||
| nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index)); | |||
| } | |||
| return new Nest<TOut>(nestedObjects); | |||
| } | |||
| else if(template.NestType == NestType.Node) | |||
| { | |||
| Dictionary<string, INestStructure<TOut>> dict = new Dictionary<string, INestStructure<TOut>>(); | |||
| foreach(var (key, value) in template.DictValue!) | |||
| { | |||
| dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index); | |||
| } | |||
| return new Nest<TOut>(dict); | |||
| } | |||
| // Consider Empty as invalid type. | |||
| throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); | |||
| } | |||
| public virtual Nest<T> AsNest() | |||
| { | |||
| return this; | |||
| } | |||
| public virtual Nest<T> MergeWith(Nest<T>? other) | |||
| { | |||
| if(other is null || other == Nest<T>.Empty) | |||
| { | |||
| return this; | |||
| } | |||
| if(this == Nest<T>.Empty) | |||
| { | |||
| return other; | |||
| } | |||
| if(NestType == NestType.Node && other.NestType == NestType.Node) | |||
| { | |||
| return new Nest<T>(new Nest<T>[] { this, other }); | |||
| } | |||
| else if(NestType == NestType.List && other.NestType == NestType.List) | |||
| { | |||
| return new Nest<T>(this.ListValue!.Concat(other.ListValue!)); | |||
| } | |||
| else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) | |||
| { | |||
| return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); | |||
| } | |||
| else | |||
| { | |||
| return new Nest<T>(new Nest<T>[] { this, other }); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not | |||
| /// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public bool IsNested() | |||
| { | |||
| if(NestType is NestType.Empty or NestType.Node) | |||
| { | |||
| return false; | |||
| } | |||
| else if(NestType is NestType.List) | |||
| { | |||
| return ListValue!.Count > 0; | |||
| } | |||
| else | |||
| { | |||
| return DictValue!.Count > 0; | |||
| } | |||
| } | |||
| [Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] | |||
| public T this[int index] | |||
| { | |||
| get | |||
| { | |||
| bool success = FindInternal(this, index, out var result); | |||
| if (success) | |||
| { | |||
| return result; | |||
| } | |||
| else | |||
| { | |||
| throw new IndexOutOfRangeException(); | |||
| } | |||
| } | |||
| set | |||
| { | |||
| bool success = SetInternal(this, index, value); | |||
| if (!success) | |||
| { | |||
| throw new IndexOutOfRangeException(); | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it | |||
| /// to `Nest[T]`. | |||
| /// </summary> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="input"></param> | |||
| /// <returns></returns> | |||
| public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T> | |||
| { | |||
| var nested = input.AsNest(); | |||
| return ReduceInternal(nested).AsNest(); | |||
| } | |||
| private static INestStructure<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||
| { | |||
| if(node.NestType == NestType.Empty) | |||
| { | |||
| return Nest<T>.Empty; | |||
| } | |||
| else if(node.NestType == NestType.Node) | |||
| { | |||
| return node.NodeValue!.AsNest(); | |||
| } | |||
| else if(node.NestType == NestType.List) | |||
| { | |||
| return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x.AsNest()))); | |||
| } | |||
| else // Dictionary type | |||
| { | |||
| return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest()))); | |||
| } | |||
| } | |||
| private static bool FindInternal(Nest<T> node, int index, out T? result) | |||
| { | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| if(index == 0) | |||
| { | |||
| result = node.NodeValue!; | |||
| return true; | |||
| } | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| foreach (var item in node.ListValue!) | |||
| { | |||
| if(index == 0) | |||
| { | |||
| return FindInternal(item.AsNest(), index, out result); | |||
| } | |||
| index--; | |||
| } | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| else if(node.NestType == NestType.Dictionary) | |||
| { | |||
| foreach (var item in node.DictValue!.Values) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| return FindInternal(item.AsNest(), index, out result); | |||
| } | |||
| index--; | |||
| } | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| else | |||
| { | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| } | |||
| private static bool SetInternal(Nest<T> node, int index, T newValue) | |||
| { | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| node.NodeValue = newValue; | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| foreach (var item in node.ListValue!) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| return SetInternal(item.AsNest(), index, newValue); | |||
| } | |||
| index--; | |||
| } | |||
| return false; | |||
| } | |||
| else if (node.NestType == NestType.Dictionary) | |||
| { | |||
| foreach (var item in node.DictValue!.Values) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| return SetInternal(item.AsNest(), index, newValue); | |||
| } | |||
| index--; | |||
| } | |||
| return false; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| private static IEnumerable<T> FlattenInternal(Nest<T> node) | |||
| { | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| yield return node.NodeValue!; | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| foreach (var item in node.ListValue!) | |||
| { | |||
| foreach(var val in FlattenInternal(item.AsNest())) | |||
| { | |||
| yield return val; | |||
| } | |||
| } | |||
| } | |||
| else if (node.NestType == NestType.Dictionary) | |||
| { | |||
| foreach (var item in node.DictValue!.Values) | |||
| { | |||
| foreach (var val in FlattenInternal(item.AsNest())) | |||
| { | |||
| yield return val; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func) | |||
| { | |||
| if (NestType == NestType.Node) | |||
| { | |||
| return new Nest<TOut>(func(NodeValue!)); | |||
| } | |||
| else if (NestType == NestType.List) | |||
| { | |||
| List<Nest<TOut>> outs = new List<Nest<TOut>>(); | |||
| foreach (var item in ListValue!) | |||
| { | |||
| outs.Add(item.AsNest().MapStructureInternal(func)); | |||
| } | |||
| return new Nest<TOut>(outs); | |||
| } | |||
| else if (NestType == NestType.Dictionary) | |||
| { | |||
| Dictionary<string, INestStructure<TOut>> outs = new Dictionary<string, INestStructure<TOut>>(); | |||
| foreach (var (key, value) in DictValue!) | |||
| { | |||
| outs.Add(key, value.AsNest().MapStructureInternal(func)); | |||
| } | |||
| return new Nest<TOut>(outs); | |||
| } | |||
| else | |||
| { | |||
| return Nest<TOut>.Empty; | |||
| } | |||
| } | |||
| public IEnumerator<T> GetEnumerator() | |||
| { | |||
| return Flatten().GetEnumerator(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| return GetEnumerator(); | |||
| } | |||
| public override string ToString() | |||
| { | |||
| StringBuilder sb = new StringBuilder(); | |||
| sb.Append("("); | |||
| WriteString(this, sb); | |||
| sb.Append(")"); | |||
| return sb.ToString(); | |||
| } | |||
| private static void WriteString(Nest<T> node, StringBuilder sb) | |||
| { | |||
| if (!string.IsNullOrEmpty(node.Name)) | |||
| { | |||
| sb.Append($"{node.Name}: "); | |||
| } | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| sb.Append(node.NodeValue!.ToString()); | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| sb.Append("["); | |||
| for(int i = 0; i < node.ListValue!.Count; i++) | |||
| { | |||
| WriteString(node.ListValue![i].AsNest(), sb); | |||
| if(i != node.ListValue!.Count - 1) | |||
| { | |||
| sb.Append(", "); | |||
| } | |||
| } | |||
| sb.Append("]"); | |||
| } | |||
| else if (node.NestType == NestType.Dictionary) | |||
| { | |||
| sb.Append("{"); | |||
| int count = node.DictValue!.Count; | |||
| int i = 0; | |||
| foreach (var (key, value) in node.DictValue!) | |||
| { | |||
| sb.Append($"{key}: "); | |||
| WriteString(value.AsNest(), sb); | |||
| if (i != count - 1) | |||
| { | |||
| sb.Append(", "); | |||
| } | |||
| i++; | |||
| } | |||
| sb.Append("}"); | |||
| } | |||
| else | |||
| { | |||
| sb.Append("<empty>"); | |||
| } | |||
| } | |||
| public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>) inputs) | |||
| { | |||
| return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2 }); | |||
| } | |||
| public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>, INestStructure<T>) inputs) | |||
| { | |||
| return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2, inputs.Item3 }); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,103 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull | |||
| { | |||
| public NestType NestType => NestType.Dictionary; | |||
| public IDictionary<TKey, TValue> Value { get; set; } | |||
| public int ShallowNestedCount => Values.Count; | |||
| public int TotalNestedCount => Values.Count; | |||
| public NestDictionary(IDictionary<TKey, TValue> dict) | |||
| { | |||
| Value = dict; | |||
| } | |||
| public IEnumerable<TValue> Flatten() | |||
| { | |||
| return Value.Select(x => x.Value); | |||
| } | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func) | |||
| { | |||
| return new NestList<TOut>(Value.Select(x => func(x.Value))); | |||
| } | |||
| public Nest<TValue> AsNest() | |||
| { | |||
| return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x))); | |||
| } | |||
| // Required IDictionary<TKey, TValue> members | |||
| public int Count => Value.Count; | |||
| public bool IsReadOnly => Value.IsReadOnly; | |||
| public ICollection<TKey> Keys => Value.Keys; | |||
| public ICollection<TValue> Values => Value.Values; | |||
| public void Add(TKey key, TValue value) | |||
| { | |||
| Value.Add(key, value); | |||
| } | |||
| public void Add(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| Value.Add(item); | |||
| } | |||
| public void Clear() | |||
| { | |||
| Value.Clear(); | |||
| } | |||
| public bool Contains(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| return Value.Contains(item); | |||
| } | |||
| public bool ContainsKey(TKey key) | |||
| { | |||
| return Value.ContainsKey(key); | |||
| } | |||
| public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||
| { | |||
| Value.CopyTo(array, arrayIndex); | |||
| } | |||
| public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||
| { | |||
| return Value.GetEnumerator(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| return GetEnumerator(); | |||
| } | |||
| public bool Remove(TKey key) | |||
| { | |||
| return Value.Remove(key); | |||
| } | |||
| public bool Remove(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| return Value.Remove(item); | |||
| } | |||
| public bool TryGetValue(TKey key, out TValue value) | |||
| { | |||
| return Value.TryGetValue(key, out value); | |||
| } | |||
| // Optional IDictionary<TKey, TValue> members | |||
| public TValue this[TKey key] | |||
| { | |||
| get => Value[key]; | |||
| set => Value[key] = value; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,53 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// The implementation of a list that support nest structure, in which the depth is 1. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| public sealed class NestList<T> : INestStructure<T>, IEnumerable<T> | |||
| { | |||
| public NestType NestType => NestType.List; | |||
| public List<T> Values { get; set; } | |||
| public int ShallowNestedCount => Values.Count; | |||
| public int TotalNestedCount => Values.Count; | |||
| public NestList(params T[] values) | |||
| { | |||
| Values = new List<T>(values); | |||
| } | |||
| public NestList(IEnumerable<T> values) | |||
| { | |||
| Values = new List<T>(values); | |||
| } | |||
| public IEnumerable<T> Flatten() | |||
| { | |||
| return Values; | |||
| } | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
| { | |||
| return new NestList<TOut>(Values.Select(x => func(x))); | |||
| } | |||
| public Nest<T> AsNest() | |||
| { | |||
| return new Nest<T>(Values.Select(x => new Nest<T>(x))); | |||
| } | |||
| // Enumerator implementation | |||
| public IEnumerator<T> GetEnumerator() | |||
| { | |||
| return Values.GetEnumerator(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| return GetEnumerator(); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,36 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// A nested structure with only one element. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| public class NestNode<T> : INestStructure<T> | |||
| { | |||
| public NestType NestType => NestType.Node; | |||
| public T Value { get; set; } | |||
| public int ShallowNestedCount => 1; | |||
| public int TotalNestedCount => 1; | |||
| public NestNode(T value) | |||
| { | |||
| Value = value; | |||
| } | |||
| public IEnumerable<T> Flatten() | |||
| { | |||
| yield return Value; | |||
| } | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
| { | |||
| return new NestNode<TOut>(func(Value)); | |||
| } | |||
| public Nest<T> AsNest() | |||
| { | |||
| return new Nest<T>(Value); | |||
| } | |||
| } | |||
| } | |||
| @@ -3,7 +3,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| namespace Tensorflow.Keras.Saving | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public class TensorShapeConfig | |||
| { | |||
| @@ -161,8 +161,8 @@ namespace Tensorflow | |||
| break; | |||
| } | |||
| yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ? | |||
| null : new Tensors(results.Skip(FirstInputTensorCount))); | |||
| yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ? | |||
| null : new Tensors(results.Skip(FirstInputTensorCount).ToArray())); | |||
| } | |||
| } | |||
| @@ -352,13 +352,19 @@ namespace Tensorflow.Eager | |||
| c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); | |||
| break; | |||
| case TF_AttrType.TF_ATTR_SHAPE: | |||
| var dims = (value as long[]).ToArray(); | |||
| long[] dims; | |||
| if (value is Shape shape) dims = shape.dims.ToArray(); | |||
| else if (value is long[] longs) dims = longs.ToArray(); | |||
| else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray(); | |||
| else dims = ((long[])value).ToArray(); | |||
| c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); | |||
| status.Check(true); | |||
| break; | |||
| case TF_AttrType.TF_ATTR_FUNC: | |||
| if (value is ConcreteFunction func) | |||
| c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | |||
| else if(value is string str) | |||
| c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length); | |||
| else | |||
| throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | |||
| break; | |||
| @@ -65,7 +65,7 @@ namespace Tensorflow.Eager | |||
| { | |||
| outgrad_vec = output_gradients.ToList(); | |||
| } | |||
| var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false); | |||
| var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true); | |||
| bool unconnected_gradients_zero = unconnected_gradients == "zero"; | |||
| @@ -137,7 +137,6 @@ namespace Tensorflow.Eager | |||
| { | |||
| dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | |||
| } | |||
| Shape tensor_shape = new(dims); | |||
| if(status.Code != TF_Code.TF_OK) | |||
| { | |||
| @@ -145,6 +144,7 @@ namespace Tensorflow.Eager | |||
| } | |||
| else | |||
| { | |||
| Shape tensor_shape = new(dims); | |||
| return new TapeTensor(id, dtype, tensor_shape); | |||
| } | |||
| } | |||
| @@ -173,8 +173,12 @@ namespace Tensorflow.Eager | |||
| return dtype == dtypes.variant || dtype == dtypes.resource; | |||
| } | |||
| bool ListContainNone(long[] list) | |||
| bool ListContainNone(long[]? list) | |||
| { | |||
| if(list is null) | |||
| { | |||
| return true; | |||
| } | |||
| int len = list.Length; | |||
| if(len == 0) | |||
| { | |||
| @@ -10,6 +10,11 @@ namespace Tensorflow.Eager | |||
| var str = NDArrayRender.ToString(nd); | |||
| return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | |||
| } | |||
| public string ToString(int maxLength) | |||
| { | |||
| var nd = new NDArray(this); | |||
| var str = NDArrayRender.ToString(nd, maxLength); | |||
| return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,25 @@ | |||
| using Tensorflow; | |||
| internal static class GraphOnlyOps | |||
| { | |||
| /// <summary> | |||
| /// Graph-only version of tf.compat.v1.placeholder(), for internal use only. | |||
| /// </summary> | |||
| /// <param name="dtyype"></param> | |||
| /// <param name="shape"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null) | |||
| { | |||
| var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() }; | |||
| var shape_value = new AttrValue() { Shape = shape.as_proto() }; | |||
| var g = ops.get_default_graph(); | |||
| Dictionary<string, AttrValue> attrs = new(); | |||
| attrs["dtype"] = dtype_value; | |||
| attrs["shape"] = shape_value; | |||
| var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype }, | |||
| new TF_DataType[0], attrs: attrs, name: name); | |||
| var result = op.outputs[0]; | |||
| return result; | |||
| } | |||
| } | |||
| @@ -0,0 +1,19 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Exceptions | |||
| { | |||
| public class NotOkStatusException : TensorflowException | |||
| { | |||
| public NotOkStatusException() : base() | |||
| { | |||
| } | |||
| public NotOkStatusException(string message) : base(message) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -49,12 +49,25 @@ namespace Tensorflow.Framework | |||
| public static implicit operator Tensor(IndexedSlices indexedSlices) | |||
| { | |||
| return indexedSlices.values; | |||
| return _indexed_slices_to_tensor(indexedSlices); | |||
| } | |||
| public static implicit operator IndexedSlices(Tensor tensor) | |||
| { | |||
| return tensor.Tag as IndexedSlices; | |||
| } | |||
| /// <summary> | |||
| /// Converts an IndexedSlices object `value` to a Tensor. | |||
| /// </summary> | |||
| /// <param name="indexedSlices"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="as_ref"></param> | |||
| /// <returns></returns> | |||
| public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false) | |||
| { | |||
| return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0)); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow.Framework.Models | |||
| { | |||
| @@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models | |||
| shapes.Insert(0, dim); | |||
| return new TensorSpec(shapes.ToArray(), _dtype); | |||
| } | |||
| public static TensorSpec FromTensor(Tensor tensor, string? name = null) | |||
| { | |||
| if(tensor is EagerTensor) | |||
| { | |||
| return new TensorSpec(tensor.shape, tensor.dtype, name); | |||
| } | |||
| else | |||
| { | |||
| return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,89 @@ | |||
| using Tensorflow.Graphs; | |||
| namespace Tensorflow.Framework | |||
| { | |||
| internal static class auto_control_deps_utils | |||
| { | |||
| public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"; | |||
| public static List<int> get_read_only_resource_input_indices_graph(FuncGraph func_graph) | |||
| { | |||
| List<int> result = new List<int>(); | |||
| // A cache to store the read only resource inputs of an Op. | |||
| // Operation -> ObjectIdentitySet of resource handles. | |||
| Dictionary<Operation, HashSet<Tensor>> opReadOnlyResourceInputs = | |||
| new Dictionary<Operation, HashSet<Tensor>>(); | |||
| for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++) | |||
| { | |||
| Tensor t = func_graph.Inputs[inputIndex]; | |||
| if (t.dtype != dtypes.resource) | |||
| continue; | |||
| bool readOnly = true; | |||
| foreach (var op in t.consumers()) | |||
| { | |||
| if (opReadOnlyResourceInputs.ContainsKey(op)) | |||
| { | |||
| if (!opReadOnlyResourceInputs[op].Contains(t)) | |||
| { | |||
| readOnly = false; | |||
| break; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| List<int> indices = _get_read_only_resource_input_indices_op(op); | |||
| opReadOnlyResourceInputs[op] = new HashSet<Tensor>( | |||
| indices.Select(i => op.inputs[i])); | |||
| if (!opReadOnlyResourceInputs[op].Contains(t)) | |||
| { | |||
| readOnly = false; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (readOnly) | |||
| result.Add(inputIndex); | |||
| } | |||
| return result; | |||
| } | |||
| private static List<int> _get_read_only_resource_input_indices_op(Operation op) | |||
| { | |||
| // ignore the RESOURCE_READ_OPS | |||
| int[] read_only_input_indices; | |||
| try | |||
| { | |||
| read_only_input_indices = op.get_attr<int[]>(READ_ONLY_RESOURCE_INPUTS_ATTR); | |||
| } | |||
| catch (InvalidArgumentError) | |||
| { | |||
| return new List<int>(); | |||
| } | |||
| int read_only_index = 0; | |||
| List<int> result = new(); | |||
| for (int i = 0; i < op.inputs.Length; i++) | |||
| { | |||
| if (read_only_index >= read_only_input_indices.Length) | |||
| { | |||
| break; | |||
| } | |||
| if (op.inputs[i].dtype != dtypes.resource) | |||
| { | |||
| continue; | |||
| } | |||
| if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index]) | |||
| { | |||
| result.Add(i); | |||
| read_only_index++; | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| } | |||
| } | |||
| @@ -42,10 +42,10 @@ namespace Tensorflow.Framework | |||
| func_graph.as_default(); | |||
| importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | |||
| var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | |||
| func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
| func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||
| var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | |||
| func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
| func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||
| // TODO(Rinne): func_graph.ControlOutputs | |||
| _set_handle_data(func_graph, fdef); | |||
| @@ -8,6 +8,7 @@ using Tensorflow.Gradients; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Util; | |||
| using Tensorflow.Common.Extensions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Functions | |||
| @@ -40,6 +41,18 @@ namespace Tensorflow.Functions | |||
| public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | |||
| public IEnumerable<IVariableV1> Variables => func_graph.Variables; | |||
| public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | |||
| internal NameAttrList AsNameAttrList | |||
| { | |||
| get | |||
| { | |||
| NameAttrList ret = new() { Name = this.Name }; | |||
| foreach (var (name, value) in _attrs) | |||
| { | |||
| ret.Attr[name] = value; | |||
| } | |||
| return ret; | |||
| } | |||
| } | |||
| public ConcreteFunction(string name) | |||
| { | |||
| @@ -3,4 +3,7 @@ global using System.Collections.Generic; | |||
| global using System.Text; | |||
| global using System.Collections; | |||
| global using System.Data; | |||
| global using System.Linq; | |||
| global using System.Linq; | |||
| global using Tensorflow.Keras.Engine; | |||
| global using Tensorflow.Framework.Models; | |||
| global using static Tensorflow.Binding; | |||
| @@ -90,8 +90,7 @@ namespace Tensorflow.Gradients | |||
| ? input_values[0].rank + dim_int | |||
| : dim_int % input_values[0].rank; | |||
| var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray(); | |||
| var sizes_tensor = constant_op.constant(sizes); | |||
| out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList(); | |||
| out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList(); | |||
| } | |||
| else if (constant_op.is_constant(concat_dim)) | |||
| { | |||
| @@ -127,7 +126,7 @@ namespace Tensorflow.Gradients | |||
| new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | |||
| new Tensor[] { tf.constant(1), tf.constant(-1) }); | |||
| var squeeze_sizes = array_ops.squeeze(slice); | |||
| out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList(); | |||
| out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList(); | |||
| } | |||
| else | |||
| { | |||
| @@ -374,5 +373,13 @@ namespace Tensorflow.Gradients | |||
| var p = op.inputs[1]; | |||
| return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null }; | |||
| } | |||
| [RegisterGradient("ReverseV2")] | |||
| public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads) | |||
| { | |||
| var grad = grads[0]; | |||
| var axis = op.inputs[1]; | |||
| return new Tensor[] { array_ops.reverse(grad, axis), null }; | |||
| } | |||
| } | |||
| } | |||
| @@ -117,6 +117,137 @@ namespace Tensorflow.Gradients | |||
| }; | |||
| } | |||
| public static string ellipsis = "..."; | |||
| [RegisterGradient("Einsum")] | |||
| public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads) | |||
| { | |||
| // Gradient for Einsum. | |||
| string equation = (string)op.get_attr("equation"); | |||
| string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None); | |||
| var input_subs = split_equation[0]; | |||
| var output_subs = split_equation[1]; | |||
| if (op.inputs.Length == 1) | |||
| { | |||
| var input_shape = array_ops.shape(op.inputs[0]); | |||
| var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + ellipsis))); | |||
| if (reduced_label_set.Count == 0) | |||
| return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) }; | |||
| return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) }; | |||
| } | |||
| string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None); | |||
| var x_subs = split_input_subs[0]; | |||
| var y_subs = split_input_subs[1]; | |||
| // Add ellipsis for broadcasted dimensions if any operand does not have it. | |||
| // This is because the equation "...ij,jk->ik" may be valid if the 0th input's | |||
| // batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid | |||
| // because only the output subscripts contain ellipsis. | |||
| if (output_subs.Contains(ellipsis)) | |||
| { | |||
| if (!x_subs.Contains(ellipsis)) | |||
| x_subs += ellipsis; | |||
| if (!y_subs.Contains(ellipsis)) | |||
| y_subs += ellipsis; | |||
| } | |||
| // Obtain the gradients wrt the inputs x and y, without taking into account | |||
| // the unbroadcasting. | |||
| var x = op.inputs[0]; | |||
| var y = op.inputs[1]; | |||
| if (grads.GetDataType().is_complex()) | |||
| { | |||
| x = math_ops.conj(x); | |||
| y = math_ops.conj(y); | |||
| } | |||
| var x_shape = array_ops.shape(x); | |||
| var y_shape = array_ops.shape(y); | |||
| var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs); | |||
| var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs); | |||
| if (!output_subs.Contains(ellipsis)) | |||
| return new Tensor[] { grad_x, grad_y }; | |||
| var bx = _GetBcastSubshape(x_subs); | |||
| int bx_start = bx[0], bx_end = bx[1]; | |||
| var by = _GetBcastSubshape(y_subs); | |||
| int by_start = by[0], by_end = by[1]; | |||
| var x_shape_static = x.shape; | |||
| var y_shape_static = y.shape; | |||
| if(x_shape_static.IsFullyDefined && | |||
| y_shape_static.IsFullyDefined && | |||
| x_shape_static[string.Format("{0}:{1}",bx_start,bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)]) | |||
| return new Tensor[] { grad_x, grad_y }; | |||
| var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)], | |||
| y_shape[string.Format("{0}:{1}", by_start, by_end)]); | |||
| var rx = r[0]; | |||
| var ry = r[1]; | |||
| grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape); | |||
| grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape); | |||
| return new Tensor[] { grad_x, grad_y }; | |||
| } | |||
| protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape, | |||
| string input_subs, string other_subs, string output_subs) | |||
| { | |||
| var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + other_subs + "."))); | |||
| var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); | |||
| var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs), new Tensors((Tensors)output_grads, other_operand)); | |||
| if (reduced_label_set.Count == 0) | |||
| return grad_reduced; | |||
| return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set); | |||
| } | |||
| protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set) | |||
| { | |||
| string reduced_subs; | |||
| Tensor reduced_dims; | |||
| List<int> reduced_axes; | |||
| _GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes); | |||
| bool has_repeated_labels = ( | |||
| new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count < | |||
| input_subs.Length + output_subs.Length); | |||
| var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); | |||
| if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs) | |||
| { | |||
| var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes)); | |||
| return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape); | |||
| } | |||
| else | |||
| { | |||
| var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0); | |||
| var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0); | |||
| var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels); | |||
| return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad)); | |||
| } | |||
| } | |||
| protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes) | |||
| { | |||
| reduced_subs = string.Join("", reduced_label_set.Select(c => c.ToString())); | |||
| reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList(); | |||
| reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList()); | |||
| } | |||
| protected static int _GetAxisFromLabel(string subscripts, char label) | |||
| { | |||
| var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None); | |||
| var index = splits[0].IndexOf(label); | |||
| if (index != -1) return index; | |||
| if (splits.Length < 2) throw new OutOfRangeError(); | |||
| index = splits[1].IndexOf(label); | |||
| if (index != -1) return index; | |||
| throw new ValueError(); | |||
| } | |||
| protected static int[] _GetBcastSubshape(string subscripts) | |||
| { | |||
| int start = subscripts.IndexOf(ellipsis); | |||
| if (start == -1) return new int[] { 0, 0 }; | |||
| int remaining = subscripts.Length - (start + ellipsis.Length); | |||
| int end; | |||
| if (remaining > 0) end = remaining; | |||
| else throw new Exception(); | |||
| return new int[] { start, end }; | |||
| } | |||
| /// <summary> | |||
| /// Returns grad * exp(x). | |||
| /// </summary> | |||
| @@ -365,6 +365,23 @@ namespace Tensorflow.Gradients | |||
| }; | |||
| } | |||
| [RegisterGradient("AvgPool")] | |||
| public static Tensor[] _AvgPoolGrad(Operation op, Tensor[] grads) | |||
| { | |||
| Tensor grad = grads[0]; | |||
| return new Tensor[] | |||
| { | |||
| gen_nn_ops.avg_pool_grad( | |||
| array_ops.shape(op.inputs[0]), | |||
| grad, | |||
| op.get_attr_list<int>("ksize"), | |||
| op.get_attr_list<int>("strides"), | |||
| op.get_attr<string>("padding"), | |||
| op.get_attr<string>("data_format")) | |||
| }; | |||
| } | |||
| /// <summary> | |||
| /// Return the gradients for TopK. | |||
| /// </summary> | |||
| @@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable | |||
| public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | |||
| public Dictionary<string, AttrValue> Attrs { get; set; } | |||
| Dictionary<long, (Tensor, Tensor)> _captures | |||
| internal Dictionary<long, (Tensor, Tensor)> _captures | |||
| = new Dictionary<long, (Tensor, Tensor)>(); | |||
| public Tensor[] external_captures | |||
| @@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable | |||
| var flat_func_args = nest.flatten(func_args as object); | |||
| var flat_func_kwargs = nest.flatten(func_kwargs as object); | |||
| func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | |||
| .Where(x => x is Tensor).Select(x => (Tensor)x)); | |||
| .Where(x => x is Tensor).Select(x => (Tensor)x).ToArray()); | |||
| //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | |||
| //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | |||
| @@ -544,12 +544,12 @@ public class FuncGraph : Graph, IDisposable | |||
| Tensor placeholder; | |||
| try | |||
| { | |||
| placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); | |||
| placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name); | |||
| } | |||
| catch (ValueError) | |||
| catch (ValueError ex) | |||
| { | |||
| // TODO(Rinne): Add warning here. | |||
| placeholder = tf.placeholder(tensor.dtype, tensor.shape); | |||
| tf.Logger.Warning(ex.ToString()); | |||
| placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape); | |||
| } | |||
| handle_data_util.copy_handle_data(tensor, placeholder); | |||
| if (name is not null) | |||
| @@ -575,12 +575,12 @@ public class FuncGraph : Graph, IDisposable | |||
| Tensor placeholder; | |||
| try | |||
| { | |||
| placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); | |||
| placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name); | |||
| } | |||
| catch (ValueError) | |||
| { | |||
| // TODO(Rinne): Add warning here. | |||
| placeholder = tf.placeholder(spec.dtype, spec.shape); | |||
| placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape); | |||
| } | |||
| if (name is not null) | |||
| { | |||
| @@ -129,7 +129,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| protected Graph outer_graph; | |||
| internal Graph outer_graph; | |||
| public Graph OuterGraph => outer_graph; | |||
| public Dictionary<string, EagerDefinedFunction> Functions => _functions; | |||
| public SafeGraphHandle c_graph => _handle; | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class ExponentialArgs : LayerArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class HardSigmoidArgs : LayerArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class SELUArgs : LayerArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class SoftplusArgs : LayerArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class SoftsignArgs : LayerArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class SwishArgs : LayerArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class TanhArgs : LayerArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class Conv2DTransposeArgs : Conv2DArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class AddArgs : MergeArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class ConcatenateArgs : MergeArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class SubtractArgs : MergeArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class GlobalAveragePooling1DArgs : Pooling1DArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class GlobalAveragePooling2DArgs : Pooling2DArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class GlobalMaxPooling1DArgs : Pooling1DArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class GlobalMaxPooling2DArgs : Pooling2DArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class MaxPooling1DArgs : Pooling1DArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -7,7 +7,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
| [JsonProperty("size")] | |||
| public Shape Size { get; set; } | |||
| [JsonProperty("data_format")] | |||
| public string DataFormat { get; set; } | |||
| public string DataFormat { get; set; } = "channels_last"; | |||
| /// <summary> | |||
| /// 'nearest', 'bilinear' | |||
| /// </summary> | |||
| @@ -0,0 +1,10 @@ | |||
| using Newtonsoft.Json; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class UpSampling1DArgs : AutoSerializeLayerArgs | |||
| { | |||
| [JsonProperty("size")] | |||
| public int Size { get; set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,20 @@ | |||
| using Newtonsoft.Json; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class BidirectionalArgs : AutoSerializeLayerArgs | |||
| { | |||
| [JsonProperty("layer")] | |||
| public ILayer Layer { get; set; } | |||
| [JsonProperty("merge_mode")] | |||
| public string? MergeMode { get; set; } | |||
| [JsonProperty("backward_layer")] | |||
| public ILayer BackwardLayer { get; set; } | |||
| public NDArray Weights { get; set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,29 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class GRUArgs : AutoSerializeLayerArgs | |||
| { | |||
| public int Units { get; set; } | |||
| public Activation Activation { get; set; } | |||
| public Activation RecurrentActivation { get; set; } | |||
| public bool UseBias { get; set; } = true; | |||
| public float Dropout { get; set; } = .0f; | |||
| public float RecurrentDropout { get; set; } = .0f; | |||
| public IInitializer KernelInitializer { get; set; } | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| public IInitializer BiasInitializer { get; set; } | |||
| public bool ReturnSequences { get;set; } | |||
| public bool ReturnState { get;set; } | |||
| public bool GoBackwards { get;set; } | |||
| public bool Stateful { get;set; } | |||
| public bool Unroll { get;set; } | |||
| public bool TimeMajor { get;set; } | |||
| public bool ResetAfter { get;set; } | |||
| public int Implementation { get; set; } = 2; | |||
| } | |||
| } | |||
| @@ -0,0 +1,39 @@ | |||
| using Newtonsoft.Json; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class GRUCellArgs : AutoSerializeLayerArgs | |||
| { | |||
| [JsonProperty("units")] | |||
| public int Units { get; set; } | |||
| // TODO(Rinne): lack of initialized value of Activation. Merging keras | |||
| // into tf.net could resolve it. | |||
| [JsonProperty("activation")] | |||
| public Activation Activation { get; set; } | |||
| [JsonProperty("recurrent_activation")] | |||
| public Activation RecurrentActivation { get; set; } | |||
| [JsonProperty("use_bias")] | |||
| public bool UseBias { get; set; } = true; | |||
| [JsonProperty("dropout")] | |||
| public float Dropout { get; set; } = .0f; | |||
| [JsonProperty("recurrent_dropout")] | |||
| public float RecurrentDropout { get; set; } = .0f; | |||
| [JsonProperty("kernel_initializer")] | |||
| public IInitializer KernelInitializer { get; set; } | |||
| [JsonProperty("recurrent_initializer")] | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| [JsonProperty("bias_initializer")] | |||
| public IInitializer BiasInitializer { get; set; } | |||
| [JsonProperty("reset_after")] | |||
| public bool ResetAfter { get;set; } | |||
| [JsonProperty("implementation")] | |||
| public int Implementation { get; set; } = 2; | |||
| } | |||
| } | |||
| @@ -0,0 +1,13 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class GRUOptionalArgs | |||
| { | |||
| public string Identifier => "GRU"; | |||
| public Tensor Mask { get; set; } = null; | |||
| } | |||
| } | |||
| @@ -1,11 +1,14 @@ | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class LSTMArgs : RNNArgs | |||
| { | |||
| // TODO: maybe change the `RNNArgs` and implement this class. | |||
| public bool UnitForgetBias { get; set; } | |||
| public float Dropout { get; set; } | |||
| public float RecurrentDropout { get; set; } | |||
| public int Implementation { get; set; } | |||
| public LSTMArgs Clone() | |||
| { | |||
| return (LSTMArgs)MemberwiseClone(); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,7 +1,35 @@ | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| using Newtonsoft.Json; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| // TODO: complete the implementation | |||
| public class LSTMCellArgs : LayerArgs | |||
| public class LSTMCellArgs : AutoSerializeLayerArgs | |||
| { | |||
| [JsonProperty("units")] | |||
| public int Units { get; set; } | |||
| // TODO(Rinne): lack of initialized value of Activation. Merging keras | |||
| // into tf.net could resolve it. | |||
| [JsonProperty("activation")] | |||
| public Activation Activation { get; set; } | |||
| [JsonProperty("recurrent_activation")] | |||
| public Activation RecurrentActivation { get; set; } | |||
| [JsonProperty("use_bias")] | |||
| public bool UseBias { get; set; } = true; | |||
| [JsonProperty("dropout")] | |||
| public float Dropout { get; set; } = .0f; | |||
| [JsonProperty("recurrent_dropout")] | |||
| public float RecurrentDropout { get; set; } = .0f; | |||
| [JsonProperty("kernel_initializer")] | |||
| public IInitializer KernelInitializer { get; set; } | |||
| [JsonProperty("recurrent_initializer")] | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| [JsonProperty("bias_initializer")] | |||
| public IInitializer BiasInitializer { get; set; } | |||
| [JsonProperty("unit_forget_bias")] | |||
| public bool UnitForgetBias { get; set; } = true; | |||
| [JsonProperty("implementation")] | |||
| public int Implementation { get; set; } = 2; | |||
| } | |||
| } | |||
| @@ -1,17 +1,12 @@ | |||
| using Newtonsoft.Json; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.Layers; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| // TODO(Rinne): add regularizers. | |||
| public class RNNArgs : AutoSerializeLayerArgs | |||
| { | |||
| public interface IRnnArgCell : ILayer | |||
| { | |||
| object state_size { get; } | |||
| } | |||
| [JsonProperty("cell")] | |||
| // TODO: the cell should be serialized with `serialize_keras_object`. | |||
| public IRnnArgCell Cell { get; set; } = null; | |||
| [JsonProperty("return_sequences")] | |||
| public bool ReturnSequences { get; set; } = false; | |||
| [JsonProperty("return_state")] | |||
| @@ -24,31 +19,31 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| public bool Unroll { get; set; } = false; | |||
| [JsonProperty("time_major")] | |||
| public bool TimeMajor { get; set; } = false; | |||
| // TODO: Add `num_constants` and `zero_output_for_mask`. | |||
| public Dictionary<string, object> Kwargs { get; set; } = null; | |||
| public int? InputDim { get; set; } | |||
| public int? InputLength { get; set; } | |||
| // TODO: Add `num_constants` and `zero_output_for_mask`. | |||
| [JsonProperty("units")] | |||
| public int Units { get; set; } | |||
| [JsonProperty("activation")] | |||
| public Activation Activation { get; set; } | |||
| [JsonProperty("recurrent_activation")] | |||
| public Activation RecurrentActivation { get; set; } | |||
| [JsonProperty("use_bias")] | |||
| public bool UseBias { get; set; } = true; | |||
| public IInitializer KernelInitializer { get; set; } | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| public IInitializer BiasInitializer { get; set; } | |||
| [JsonProperty("dropout")] | |||
| public float Dropout { get; set; } = .0f; | |||
| [JsonProperty("zero_output_for_mask")] | |||
| public bool ZeroOutputForMask { get; set; } = false; | |||
| [JsonProperty("recurrent_dropout")] | |||
| public float RecurrentDropout { get; set; } = .0f; | |||
| // kernel_regularizer=None, | |||
| // recurrent_regularizer=None, | |||
| // bias_regularizer=None, | |||
| // activity_regularizer=None, | |||
| // kernel_constraint=None, | |||
| // recurrent_constraint=None, | |||
| // bias_constraint=None, | |||
| // dropout=0., | |||
| // recurrent_dropout=0., | |||
| // return_sequences=False, | |||
| // return_state=False, | |||
| // go_backwards=False, | |||
| // stateful=False, | |||
| // unroll=False, | |||
| // **kwargs): | |||
| public RNNArgs Clone() | |||
| { | |||
| return (RNNArgs)MemberwiseClone(); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class RnnOptionalArgs: IOptionalArgs | |||
| { | |||
| public string Identifier => "Rnn"; | |||
| public Tensor Mask { get; set; } = null; | |||
| public Tensors Constants { get; set; } = null; | |||
| } | |||
| } | |||
| @@ -1,4 +1,4 @@ | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class SimpleRNNArgs : RNNArgs | |||
| { | |||
| @@ -0,0 +1,27 @@ | |||
| using Newtonsoft.Json; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class SimpleRNNCellArgs: AutoSerializeLayerArgs | |||
| { | |||
| [JsonProperty("units")] | |||
| public int Units { get; set; } | |||
| // TODO(Rinne): lack of initialized value of Activation. Merging keras | |||
| // into tf.net could resolve it. | |||
| [JsonProperty("activation")] | |||
| public Activation Activation { get; set; } | |||
| [JsonProperty("use_bias")] | |||
| public bool UseBias { get; set; } = true; | |||
| [JsonProperty("dropout")] | |||
| public float Dropout { get; set; } = .0f; | |||
| [JsonProperty("recurrent_dropout")] | |||
| public float RecurrentDropout { get; set; } = .0f; | |||
| [JsonProperty("kernel_initializer")] | |||
| public IInitializer KernelInitializer { get; set; } | |||
| [JsonProperty("recurrent_initializer")] | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| [JsonProperty("bias_initializer")] | |||
| public IInitializer BiasInitializer { get; set; } | |||
| } | |||
| } | |||
| @@ -1,10 +1,10 @@ | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.Layers; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class StackedRNNCellsArgs : LayerArgs | |||
| { | |||
| public IList<RnnCell> Cells { get; set; } | |||
| public Dictionary<string, object> Kwargs { get; set; } = null; | |||
| public bool ReverseStateOrder = false; | |||
| } | |||
| } | |||
| @@ -0,0 +1,24 @@ | |||
| using Newtonsoft.Json; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class WrapperArgs : AutoSerializeLayerArgs | |||
| { | |||
| [JsonProperty("layer")] | |||
| public ILayer Layer { get; set; } | |||
| public WrapperArgs(ILayer layer) | |||
| { | |||
| Layer = layer; | |||
| } | |||
| public static implicit operator WrapperArgs(BidirectionalArgs args) | |||
| => new WrapperArgs(args.Layer); | |||
| } | |||
| } | |||
| @@ -14,6 +14,9 @@ public interface ICallback | |||
| void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | |||
| void on_predict_end(); | |||
| void on_test_begin(); | |||
| void on_test_end(Dictionary<string, float> logs); | |||
| void on_test_batch_begin(long step); | |||
| void on_test_batch_end(long end_step, Dictionary<string, float> logs); | |||
| } | |||
| @@ -60,7 +60,7 @@ public interface IModel : ILayer | |||
| bool skip_mismatch = false, | |||
| object options = null); | |||
| Dictionary<string, float> evaluate(Tensor x, Tensor y, | |||
| Dictionary<string, float> evaluate(NDArray x, NDArray y, | |||
| int batch_size = -1, | |||
| int verbose = 1, | |||
| int steps = -1, | |||
| @@ -0,0 +1,75 @@ | |||
| namespace Tensorflow.Keras.Engine; | |||
| /// <summary> | |||
| /// A representation of a Keras in/output during Functional API construction. | |||
| /// </summary> | |||
| public class KerasTensor | |||
| { | |||
| private Tensors _original_tensors; | |||
| public Tensors original_tensors | |||
| { | |||
| get => _original_tensors; | |||
| set => _original_tensors = value; | |||
| } | |||
| private Shape _inferred_value; | |||
| public Shape inferred_value => _inferred_value; | |||
| private string _name; | |||
| private TensorSpec _type_spec; | |||
| public Shape shape => _type_spec.shape; | |||
| public TF_DataType dtype => _type_spec.dtype; | |||
| public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string name = null) | |||
| { | |||
| _type_spec = type_spec; | |||
| _inferred_value = inferred_value; | |||
| _name = name; | |||
| } | |||
| public static KerasTensor from_tensor(Tensor tensor) | |||
| { | |||
| var type_spec = tensor.ToTensorSpec(); | |||
| Shape? inferred_value = default; | |||
| if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2) | |||
| { | |||
| inferred_value = tf.ones(tensor).shape; | |||
| } | |||
| var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name); | |||
| kt.original_tensors = tensor; | |||
| return kt; | |||
| } | |||
| public KerasTensor this[int idx] | |||
| => _original_tensors.First()[idx]; | |||
| public KerasTensor this[params Slice[] slices] | |||
| => _original_tensors.First()[slices]; | |||
| public override string ToString() | |||
| => _original_tensors.Length switch | |||
| { | |||
| > 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]", | |||
| 1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}", | |||
| _ => _original_tensors.ToString(), | |||
| }; | |||
| private string GetInferredValueString() | |||
| => _inferred_value == null ? "" : $" inferred_value={_inferred_value}"; | |||
| public static implicit operator Tensors(KerasTensor kt) | |||
| => kt._original_tensors; | |||
| public static implicit operator Tensor(KerasTensor kt) | |||
| { | |||
| Tensor tensor = kt._original_tensors; | |||
| tensor.IsFromKerasTensor = true; | |||
| return tensor; | |||
| } | |||
| public static implicit operator KerasTensor(Tensor tensor) | |||
| => from_tensor(tensor); | |||
| public static implicit operator KerasTensor(Tensors tensors) | |||
| => from_tensor(tensors.First()); | |||
| } | |||
| @@ -25,6 +25,27 @@ namespace Tensorflow.Keras | |||
| bool amsgrad = false, | |||
| string name = "Adam"); | |||
| /// <summary> | |||
| /// Adam enables L2 weight decay on gradients. | |||
| /// </summary> | |||
| /// <param name="learning_rate"></param> | |||
| /// <param name="weight_decay"></param> | |||
| /// <param name="beta_1"></param> | |||
| /// <param name="beta_2"></param> | |||
| /// <param name="epsilon"></param> | |||
| /// <param name="amsgrad"></param> | |||
| /// <param name="decay_params"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| IOptimizer AdamW(float learning_rate = 0.001f, | |||
| float weight_decay = 0.004f, | |||
| float beta_1 = 0.9f, | |||
| float beta_2 = 0.999f, | |||
| float epsilon = 1e-7f, | |||
| bool amsgrad = false, | |||
| List<string> no_decay_params = null, | |||
| string name = "AdamW"); | |||
| /// <summary> | |||
| /// Construct a new RMSprop optimizer. | |||
| /// </summary> | |||
| @@ -42,6 +63,6 @@ namespace Tensorflow.Keras | |||
| bool centered = false, | |||
| string name = "RMSprop"); | |||
| IOptimizer SGD(float learning_rate); | |||
| IOptimizer SGD(float learning_rate = 0.01f, float momentum = 0f); | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Training; | |||
| @@ -14,7 +15,7 @@ namespace Tensorflow.Keras | |||
| List<ILayer> Layers { get; } | |||
| List<INode> InboundNodes { get; } | |||
| List<INode> OutboundNodes { get; } | |||
| Tensors Apply(Tensors inputs, Tensor state = null, bool training = false); | |||
| Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null); | |||
| List<IVariableV1> TrainableVariables { get; } | |||
| List<IVariableV1> TrainableWeights { get; } | |||
| List<IVariableV1> NonTrainableWeights { get; } | |||
| @@ -9,6 +9,10 @@ namespace Tensorflow.Keras.Layers | |||
| public ILayer Reshape(Shape target_shape); | |||
| public ILayer Reshape(object[] target_shape); | |||
| public ILayer UpSampling1D( | |||
| int size | |||
| ); | |||
| public ILayer UpSampling2D(Shape size = null, | |||
| string data_format = null, | |||
| string interpolation = "nearest"); | |||
| @@ -1,5 +1,7 @@ | |||
| using System; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.NumPy; | |||
| using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||
| @@ -134,7 +136,7 @@ namespace Tensorflow.Keras.Layers | |||
| public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); | |||
| public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); | |||
| public Tensors Input(Shape shape = null, | |||
| public KerasTensor Input(Shape shape = null, | |||
| int batch_size = -1, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| @@ -159,6 +161,18 @@ namespace Tensorflow.Keras.Layers | |||
| public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false); | |||
| public ILayer LeakyReLU(float alpha = 0.3f); | |||
| public IRnnCell LSTMCell(int uints, | |||
| string activation = "tanh", | |||
| string recurrent_activation = "sigmoid", | |||
| bool use_bias = true, | |||
| string kernel_initializer = "glorot_uniform", | |||
| string recurrent_initializer = "orthogonal", | |||
| string bias_initializer = "zeros", | |||
| bool unit_forget_bias = true, | |||
| float dropout = 0f, | |||
| float recurrent_dropout = 0f, | |||
| int implementation = 2); | |||
| public ILayer LSTM(int units, | |||
| Activation activation = null, | |||
| Activation recurrent_activation = null, | |||
| @@ -192,6 +206,19 @@ namespace Tensorflow.Keras.Layers | |||
| float offset = 0, | |||
| Shape input_shape = null); | |||
| public IRnnCell SimpleRNNCell( | |||
| int units, | |||
| string activation = "tanh", | |||
| bool use_bias = true, | |||
| string kernel_initializer = "glorot_uniform", | |||
| string recurrent_initializer = "orthogonal", | |||
| string bias_initializer = "zeros", | |||
| float dropout = 0f, | |||
| float recurrent_dropout = 0f); | |||
| public IRnnCell StackedRNNCells( | |||
| IEnumerable<IRnnCell> cells); | |||
| public ILayer SimpleRNN(int units, | |||
| string activation = "tanh", | |||
| string kernel_initializer = "glorot_uniform", | |||
| @@ -200,6 +227,69 @@ namespace Tensorflow.Keras.Layers | |||
| bool return_sequences = false, | |||
| bool return_state = false); | |||
| public ILayer RNN( | |||
| IRnnCell cell, | |||
| bool return_sequences = false, | |||
| bool return_state = false, | |||
| bool go_backwards = false, | |||
| bool stateful = false, | |||
| bool unroll = false, | |||
| bool time_major = false | |||
| ); | |||
| public ILayer RNN( | |||
| IEnumerable<IRnnCell> cell, | |||
| bool return_sequences = false, | |||
| bool return_state = false, | |||
| bool go_backwards = false, | |||
| bool stateful = false, | |||
| bool unroll = false, | |||
| bool time_major = false | |||
| ); | |||
| public IRnnCell GRUCell( | |||
| int units, | |||
| string activation = "tanh", | |||
| string recurrent_activation = "sigmoid", | |||
| bool use_bias = true, | |||
| string kernel_initializer = "glorot_uniform", | |||
| string recurrent_initializer = "orthogonal", | |||
| string bias_initializer = "zeros", | |||
| float dropout = 0f, | |||
| float recurrent_dropout = 0f, | |||
| bool reset_after = true); | |||
| public ILayer GRU( | |||
| int units, | |||
| string activation = "tanh", | |||
| string recurrent_activation = "sigmoid", | |||
| bool use_bias = true, | |||
| string kernel_initializer = "glorot_uniform", | |||
| string recurrent_initializer = "orthogonal", | |||
| string bias_initializer = "zeros", | |||
| float dropout = 0f, | |||
| float recurrent_dropout = 0f, | |||
| bool return_sequences = false, | |||
| bool return_state = false, | |||
| bool go_backwards = false, | |||
| bool stateful = false, | |||
| bool unroll = false, | |||
| bool time_major = false, | |||
| bool reset_after = true | |||
| ); | |||
| /// <summary> | |||
| /// Bidirectional wrapper for RNNs. | |||
| /// </summary> | |||
| /// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`</param> | |||
| /// automatically.</param> | |||
| /// <returns></returns> | |||
| public ILayer Bidirectional( | |||
| ILayer layer, | |||
| string merge_mode = "concat", | |||
| NDArray weights = null, | |||
| ILayer backward_layer = null); | |||
| public ILayer Subtract(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,25 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public interface IRnnCell: ILayer | |||
| { | |||
| /// <summary> | |||
| /// If the derived class tends to not implement it, please return null. | |||
| /// </summary> | |||
| INestStructure<long>? StateSize { get; } | |||
| /// <summary> | |||
| /// If the derived class tends to not implement it, please return null. | |||
| /// </summary> | |||
| INestStructure<long>? OutputSize { get; } | |||
| /// <summary> | |||
| /// Whether the optional RNN args are supported when appying the layer. | |||
| /// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`. | |||
| /// </summary> | |||
| bool SupportOptionalArgs { get; } | |||
| Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype); | |||
| } | |||
| } | |||
| @@ -0,0 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public interface IStackedRnnCells : IRnnCell | |||
| { | |||
| int Count { get; } | |||
| IRnnCell this[int idx] { get; } | |||
| } | |||
| } | |||
| @@ -3,6 +3,7 @@ using Newtonsoft.Json; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Saving.Json | |||
| { | |||
| @@ -6,6 +6,7 @@ using System.Text; | |||
| using System.Diagnostics; | |||
| using OneOf.Types; | |||
| using Tensorflow.Keras.Saving.Json; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Saving | |||
| { | |||
| @@ -74,8 +74,3 @@ namespace Tensorflow | |||
| => IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; | |||
| } | |||
| } | |||
| namespace System.Runtime.CompilerServices | |||
| { | |||
| internal static class IsExternalInit { } | |||
| } | |||
| @@ -107,9 +107,15 @@ namespace Tensorflow.NumPy | |||
| public static implicit operator NDArray(bool value) | |||
| => new NDArray(value); | |||
| public static implicit operator NDArray(byte value) | |||
| => new NDArray(value); | |||
| public static implicit operator NDArray(int value) | |||
| => new NDArray(value); | |||
| public static implicit operator NDArray(long value) | |||
| => new NDArray(value); | |||
| public static implicit operator NDArray(float value) | |||
| => new NDArray(value); | |||
| @@ -7,7 +7,7 @@ namespace Tensorflow.NumPy | |||
| { | |||
| public class NDArrayRender | |||
| { | |||
| public static string ToString(NDArray array) | |||
| public static string ToString(NDArray array, int maxLength = 10) | |||
| { | |||
| Shape shape = array.shape; | |||
| if (shape.IsScalar) | |||
| @@ -15,12 +15,12 @@ namespace Tensorflow.NumPy | |||
| var s = new StringBuilder(); | |||
| s.Append("array("); | |||
| Build(s, array); | |||
| Build(s, array, maxLength); | |||
| s.Append(")"); | |||
| return s.ToString(); | |||
| } | |||
| static void Build(StringBuilder s, NDArray array) | |||
| static void Build(StringBuilder s, NDArray array, int maxLength) | |||
| { | |||
| var shape = array.shape; | |||
| @@ -35,11 +35,11 @@ namespace Tensorflow.NumPy | |||
| var len = shape[0]; | |||
| s.Append("["); | |||
| if (len <= 10) | |||
| if (len <= maxLength) | |||
| { | |||
| for (int i = 0; i < len; i++) | |||
| { | |||
| Build(s, array[i]); | |||
| Build(s, array[i], maxLength); | |||
| if (i < len - 1) | |||
| { | |||
| s.Append(", "); | |||
| @@ -49,9 +49,9 @@ namespace Tensorflow.NumPy | |||
| } | |||
| else | |||
| { | |||
| for (int i = 0; i < 5; i++) | |||
| for (int i = 0; i < maxLength / 2; i++) | |||
| { | |||
| Build(s, array[i]); | |||
| Build(s, array[i], maxLength); | |||
| if (i < len - 1) | |||
| { | |||
| s.Append(", "); | |||
| @@ -62,9 +62,9 @@ namespace Tensorflow.NumPy | |||
| s.Append(" ... "); | |||
| s.AppendLine(); | |||
| for (int i = (int)len - 5; i < len; i++) | |||
| for (int i = (int)len - maxLength / 2; i < len; i++) | |||
| { | |||
| Build(s, array[i]); | |||
| Build(s, array[i], maxLength); | |||
| if (i < len - 1) | |||
| { | |||
| s.Append(", "); | |||
| @@ -13,6 +13,10 @@ namespace Tensorflow.NumPy | |||
| public static NDArray argmax(NDArray a, Axis? axis = null) | |||
| => new NDArray(math_ops.argmax(a, axis ?? 0)); | |||
| [AutoNumPy] | |||
| public static NDArray argmin(NDArray a, Axis? axis = null) | |||
| => new NDArray(math_ops.argmin(a, axis ?? 0)); | |||
| [AutoNumPy] | |||
| public static NDArray argsort(NDArray a, Axis? axis = null) | |||
| => new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); | |||
| @@ -10,10 +10,10 @@ namespace Tensorflow.NumPy | |||
| public partial class np | |||
| { | |||
| [AutoNumPy] | |||
| public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.arg_min(x, axis)); | |||
| public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.min(x, axis)); | |||
| [AutoNumPy] | |||
| public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis)); | |||
| public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.max(x, axis)); | |||
| [AutoNumPy] | |||
| public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) | |||
| @@ -49,9 +49,30 @@ namespace Tensorflow.NumPy | |||
| [AutoNumPy] | |||
| public static NDArray prod<T>(params T[] array) where T : unmanaged | |||
| => new NDArray(tf.reduce_prod(new NDArray(array))); | |||
| [AutoNumPy] | |||
| public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null) | |||
| { | |||
| //if axes mentioned | |||
| if (axes != null) | |||
| { | |||
| return new NDArray(tf.dot_prod(x1, x2, axes, name)); | |||
| } | |||
| if (x1.shape.ndim > 1) | |||
| { | |||
| x1 = GetFlattenArray(x1); | |||
| } | |||
| if (x2.shape.ndim > 1) | |||
| { | |||
| x2 = GetFlattenArray(x2); | |||
| } | |||
| //if axes not mentioned, default 0,0 | |||
| return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name)); | |||
| } | |||
| [AutoNumPy] | |||
| public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y)); | |||
| [AutoNumPy] | |||
| public static NDArray square(NDArray x) => new NDArray(tf.square(x)); | |||
| [AutoNumPy] | |||
| public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x)); | |||
| @@ -19,13 +19,14 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.Saving.Common; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow | |||
| { | |||
| [JsonConverter(typeof(CustomizedShapeJsonConverter))] | |||
| public class Shape | |||
| public class Shape : INestStructure<long> | |||
| { | |||
| public int ndim => _dims == null ? -1 : _dims.Length; | |||
| long[] _dims; | |||
| @@ -41,6 +42,27 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public NestType NestType => NestType.List; | |||
| public int ShallowNestedCount => ndim; | |||
| /// <summary> | |||
| /// The total item count of depth 1 of the nested structure. | |||
| /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. | |||
| /// </summary> | |||
| public int TotalNestedCount => ndim; | |||
| public IEnumerable<long> Flatten() => dims.Select(x => x); | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func) | |||
| { | |||
| return new NestList<TOut>(dims.Select(x => func(x))); | |||
| } | |||
| public Nest<long> AsNest() | |||
| { | |||
| return new NestList<long>(Flatten()).AsNest(); | |||
| } | |||
| #region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges | |||
| public int Length => ndim; | |||
| public long[] Slice(int start, int length) | |||
| @@ -0,0 +1,22 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow.Operations.Initializers | |||
| { | |||
| /// <summary> | |||
| /// An initializer specially used for debugging (to load weights from disk). | |||
| /// </summary> | |||
| class NpyLoadInitializer : IInitializer | |||
| { | |||
| string _path; | |||
| public NpyLoadInitializer(string path) { _path = path; } | |||
| public string ClassName => ""; | |||
| public IDictionary<string, object> Config => new Dictionary<string, object>(); | |||
| public Tensor Apply(InitializerArgs args) | |||
| { | |||
| return np.load(_path); | |||
| } | |||
| } | |||
| } | |||
| @@ -53,13 +53,12 @@ public class Orthogonal : IInitializer | |||
| // Compute the qr factorization | |||
| var (q, r) = tf.linalg.qr(a, full_matrices: false); | |||
| // Make Q uniform | |||
| var d = tf.linalg.tensor_diag_part(r); | |||
| var d = tf.linalg.tensor_diag_part(r.Single); | |||
| q *= tf.sign(d); | |||
| if (num_rows < num_cols) | |||
| { | |||
| // q = tf.linalg.matrix_transpose(q); | |||
| throw new NotImplementedException(""); | |||
| q = array_ops.matrix_transpose(q); | |||
| } | |||
| return _gain * tf.reshape(q, shape); | |||
| @@ -11,6 +11,7 @@ namespace Tensorflow | |||
| /// Basic LSTM recurrent network cell. | |||
| /// The implementation is based on: http://arxiv.org/abs/1409.2329. | |||
| /// </summary> | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public class BasicLstmCell : LayerRnnCell | |||
| { | |||
| int _num_units; | |||
| @@ -88,7 +89,7 @@ namespace Tensorflow | |||
| gate_inputs = nn_ops.bias_add(gate_inputs, _bias); | |||
| // i = input_gate, j = new_input, f = forget_gate, o = output_gate | |||
| var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one); | |||
| var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); | |||
| var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | |||
| var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | |||
| @@ -20,6 +20,7 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public class BasicRnnCell : LayerRnnCell | |||
| { | |||
| int _num_units; | |||
| @@ -19,6 +19,7 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public class LayerRnnCell : RnnCell | |||
| { | |||
| protected InputSpec inputSpec; | |||
| @@ -16,10 +16,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Operations; | |||
| @@ -50,7 +51,8 @@ namespace Tensorflow | |||
| /// matching structure of Tensors having shape `[batch_size].concatenate(s)` | |||
| /// for each `s` in `self.batch_size`. | |||
| /// </summary> | |||
| public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public abstract class RnnCell : ILayer, IRnnCell | |||
| { | |||
| /// <summary> | |||
| /// Attribute that indicates whether the cell is a TF RNN cell, due the slight | |||
| @@ -142,7 +144,7 @@ namespace Tensorflow | |||
| throw new NotImplementedException("_zero_state_tensors"); | |||
| } | |||
| public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| @@ -173,5 +175,18 @@ namespace Tensorflow | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public INestStructure<long> StateSize => throw new NotImplementedException(); | |||
| public INestStructure<long> OutputSize => throw new NotImplementedException(); | |||
| public bool IsTFRnnCell => throw new NotImplementedException(); | |||
| public bool SupportOptionalArgs => throw new NotImplementedException(); | |||
| } | |||
| } | |||
| @@ -15,9 +15,11 @@ | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using Google.Protobuf.Collections; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Functions; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.OpDef.Types; | |||
| @@ -387,9 +389,13 @@ namespace Tensorflow | |||
| case "list(type)": | |||
| attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); | |||
| break; | |||
| case "list(float)": | |||
| if (value != null) | |||
| attr_value.List.F.AddRange((value as IEnumerable<float>).ToArray()); | |||
| break; | |||
| case "list(int)": | |||
| if (value != null) | |||
| attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x))); | |||
| attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x))); | |||
| break; | |||
| case "bool": | |||
| attr_value.B = (bool)value; | |||
| @@ -420,6 +426,15 @@ namespace Tensorflow | |||
| case "list(shape)": | |||
| attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def))); | |||
| break; | |||
| case "func": | |||
| attr_value.Func = _MakeFunc(value, attr_def.Name); | |||
| break; | |||
| case "list(func)": | |||
| attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name)); | |||
| break; | |||
| case "list(string)": | |||
| attr_value.List.S.AddRange((value as IEnumerable<string>).Select(x => ByteString.CopyFromUtf8(x))); | |||
| break; | |||
| default: | |||
| throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | |||
| } | |||
| @@ -427,6 +442,47 @@ namespace Tensorflow | |||
| return attr_value; | |||
| } | |||
| private NameAttrList _MakeFunc(object func, string arg_name) | |||
| { | |||
| if(func is NameAttrList attrList) | |||
| { | |||
| return attrList; | |||
| } | |||
| NameAttrList fn_attr; | |||
| if(func is string funcStr) | |||
| { | |||
| fn_attr = new NameAttrList() { Name = funcStr }; | |||
| } | |||
| else if(func is ConcreteFunction concrete) | |||
| { | |||
| concrete.AddTograph(ops.get_default_graph()); | |||
| fn_attr = concrete.AsNameAttrList; | |||
| } | |||
| else if(func is EagerDefinedFunction eager) | |||
| { | |||
| eager.AddToGraph(ops.get_default_graph()); | |||
| fn_attr = new NameAttrList() { Name = eager.Name }; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}"); | |||
| } | |||
| return fn_attr; | |||
| } | |||
| private List<NameAttrList> _MakeFuncList(object funcList, string arg_name) | |||
| { | |||
| List<NameAttrList> res = new List<NameAttrList>(); | |||
| if(funcList is IEnumerable enumerable) | |||
| { | |||
| foreach(var func in enumerable) | |||
| { | |||
| res.Add(_MakeFunc(func, arg_name)); | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| private bool _IsListParameter(ArgDef arg) | |||
| { | |||
| if (!String.IsNullOrEmpty(arg.NumberAttr)) | |||