# Conflicts: # src/TensorFlowNET.Core/Tensorflow.Binding.csproj # src/TensorFlowNET.Keras/Datasets/Imdb.cstags/v0.110.4-Transformer-Model
| @@ -16,6 +16,7 @@ | |||||
| using System; | using System; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -50,6 +51,35 @@ namespace Tensorflow | |||||
| return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | ||||
| } | } | ||||
| public unsafe static byte[] ByteStringPiece(Buffer? handle) | |||||
| { | |||||
| if (handle is null) | |||||
| { | |||||
| return new byte[0]; | |||||
| } | |||||
| var data = handle.ToArray(); | |||||
| return data; | |||||
| } | |||||
| public unsafe static byte[] ByteStringPieceFromNativeString(IntPtr handle) | |||||
| { | |||||
| if (handle == IntPtr.Zero) | |||||
| { | |||||
| return new byte[0]; | |||||
| } | |||||
| byte* str_data = (byte*)handle.ToPointer(); | |||||
| List<byte> bytes = new List<byte>(); | |||||
| byte current = 255; | |||||
| while (current != ((byte)'\0')) | |||||
| { | |||||
| current = *(str_data++); | |||||
| bytes.Add(current); | |||||
| } | |||||
| var data = bytes.ToArray(); | |||||
| return data; | |||||
| } | |||||
| [UnmanagedFunctionPointer(CallingConvention.Winapi)] | [UnmanagedFunctionPointer(CallingConvention.Winapi)] | ||||
| public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args); | public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args); | ||||
| @@ -10,7 +10,7 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
| public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | ||||
| } | } | ||||
| @@ -91,8 +91,7 @@ namespace Tensorflow | |||||
| return identity(values.First(), name: scope); | return identity(values.First(), name: scope); | ||||
| }); | }); | ||||
| } | } | ||||
| return gen_array_ops.concat_v2(values.ToArray(), ops.convert_to_tensor(axis), name: name); | |||||
| return array_ops.concat(values.ToArray(), axis, name: name); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -163,14 +162,17 @@ namespace Tensorflow | |||||
| /// Reverses specific dimensions of a tensor. | /// Reverses specific dimensions of a tensor. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
| /// <param name="axis"></param> | |||||
| /// <param name="axis">The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)).</param> | |||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Tensor reverse(Tensor tensor, int[] axis, string name = null) | |||||
| => gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name); | |||||
| public Tensor reverse(Tensor tensor, Tensor axis, string name = null) | |||||
| => gen_array_ops.reverse(tensor, axis, name: name); | |||||
| public Tensor reverse(Tensor tensor, Axis axis, string name = null) | |||||
| { | |||||
| if (axis.IsScalar) | |||||
| { | |||||
| axis = new Axis(axis.axis); | |||||
| } | |||||
| return array_ops.reverse(tensor, axis, name: name); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the rank of a tensor. | /// Returns the rank of a tensor. | ||||
| @@ -46,10 +46,10 @@ namespace Tensorflow | |||||
| Tensor loop_vars, | Tensor loop_vars, | ||||
| int parallel_iterations = 10) | int parallel_iterations = 10) | ||||
| { | { | ||||
| Func<Tensor[], Tensor> cond1 = x | |||||
| Func<Tensors, Tensor> cond1 = x | |||||
| => cond(x[0]); | => cond(x[0]); | ||||
| Func<Tensor[], Tensor[]> body1 = x | |||||
| Func<Tensors, Tensors> body1 = x | |||||
| => new[] { body(x[0]) }; | => new[] { body(x[0]) }; | ||||
| var results = control_flow_ops.while_loop(cond1, | var results = control_flow_ops.while_loop(cond1, | ||||
| @@ -58,9 +58,9 @@ namespace Tensorflow | |||||
| return results[0]; | return results[0]; | ||||
| } | } | ||||
| public Tensor[] while_loop(Func<Tensor[], Tensor> cond, | |||||
| Func<Tensor[], Tensor[]> body, | |||||
| Tensor[] loop_vars, | |||||
| public Tensor[] while_loop(Func<Tensors, Tensor> cond, | |||||
| Func<Tensors, Tensors> body, | |||||
| Tensors loop_vars, | |||||
| int parallel_iterations = 10, | int parallel_iterations = 10, | ||||
| string name = null) | string name = null) | ||||
| => control_flow_ops.while_loop(cond, body, loop_vars, | => control_flow_ops.while_loop(cond, body, loop_vars, | ||||
| @@ -14,6 +14,10 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using OneOf.Types; | |||||
| using System; | |||||
| using System.Buffers.Text; | |||||
| using Tensorflow.Contexts; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -162,17 +166,108 @@ namespace Tensorflow | |||||
| public Tensor sobel_edges(Tensor image) | public Tensor sobel_edges(Tensor image) | ||||
| => image_ops_impl.sobel_edges(image); | => image_ops_impl.sobel_edges(image); | ||||
| public Tensor decode_jpeg(Tensor contents, | |||||
| int channels = 0, | |||||
| int ratio = 1, | |||||
| bool fancy_upscaling = true, | |||||
| bool try_recover_truncated = false, | |||||
| int acceptable_fraction = 1, | |||||
| string dct_method = "", | |||||
| string name = null) | |||||
| => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | |||||
| fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | |||||
| acceptable_fraction: acceptable_fraction, dct_method: dct_method); | |||||
| /// <summary> | |||||
| /// Adjust contrast of RGB or grayscale images. | |||||
| /// </summary> | |||||
| /// <param name="images">Images to adjust. At least 3-D.</param> | |||||
| /// <param name="contrast_factor"></param> | |||||
| /// <param name="name">A float multiplier for adjusting contrast.</param> | |||||
| /// <returns>The contrast-adjusted image or images.</returns> | |||||
| public Tensor adjust_contrast(Tensor images, float contrast_factor, string name = null) | |||||
| => gen_image_ops.adjust_contrastv2(images, contrast_factor, name); | |||||
| /// <summary> | |||||
| /// Adjust hue of RGB images. | |||||
| /// </summary> | |||||
| /// <param name="images">RGB image or images. The size of the last dimension must be 3.</param> | |||||
| /// <param name="delta">float. How much to add to the hue channel.</param> | |||||
| /// <param name="name">A name for this operation (optional).</param> | |||||
| /// <returns>Adjusted image(s), same shape and DType as `image`.</returns> | |||||
| /// <exception cref="ValueError">if `delta` is not in the interval of `[-1, 1]`.</exception> | |||||
| public Tensor adjust_hue(Tensor images, float delta, string name = null) | |||||
| { | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| if (delta < -1f || delta > 1f) | |||||
| throw new ValueError("delta must be in the interval [-1, 1]"); | |||||
| } | |||||
| return gen_image_ops.adjust_hue(images, delta, name: name); | |||||
| } | |||||
| /// <summary> | |||||
| /// Adjust saturation of RGB images. | |||||
| /// </summary> | |||||
| /// <param name="image">RGB image or images. The size of the last dimension must be 3.</param> | |||||
| /// <param name="saturation_factor">float. Factor to multiply the saturation by.</param> | |||||
| /// <param name="name">A name for this operation (optional).</param> | |||||
| /// <returns>Adjusted image(s), same shape and DType as `image`.</returns> | |||||
| public Tensor adjust_saturation(Tensor image, float saturation_factor, string name = null) | |||||
| => gen_image_ops.adjust_saturation(image, saturation_factor, name); | |||||
| /// <summary> | |||||
| /// Greedily selects a subset of bounding boxes in descending order of score. | |||||
| /// </summary> | |||||
| /// <param name="boxes"> | |||||
| /// A 4-D float `Tensor` of shape `[batch_size, num_boxes, q, 4]`. If `q` | |||||
| /// is 1 then same boxes are used for all classes otherwise, if `q` is equal | |||||
| /// to number of classes, class-specific boxes are used. | |||||
| /// </param> | |||||
| /// <param name="scores"> | |||||
| /// A 3-D float `Tensor` of shape `[batch_size, num_boxes, num_classes]` | |||||
| /// representing a single score corresponding to each box(each row of boxes). | |||||
| /// </param> | |||||
| /// <param name="max_output_size_per_class"> | |||||
| /// A scalar integer `Tensor` representing the | |||||
| /// maximum number of boxes to be selected by non-max suppression per class | |||||
| /// </param> | |||||
| /// <param name="max_total_size"> | |||||
| /// A int32 scalar representing maximum number of boxes retained | |||||
| /// over all classes.Note that setting this value to a large number may | |||||
| /// result in OOM error depending on the system workload. | |||||
| /// </param> | |||||
| /// <param name="iou_threshold"> | |||||
| /// A float representing the threshold for deciding whether boxes | |||||
| /// overlap too much with respect to IOU. | |||||
| /// </param> | |||||
| /// <param name="score_threshold"> | |||||
| /// A float representing the threshold for deciding when to | |||||
| /// remove boxes based on score. | |||||
| /// </param> | |||||
| /// <param name="pad_per_class"> | |||||
| /// If false, the output nmsed boxes, scores and classes are | |||||
| /// padded/clipped to `max_total_size`. If true, the output nmsed boxes, scores and classes are padded to be of length `max_size_per_class`*`num_classes`, | |||||
| /// unless it exceeds `max_total_size` in which case it is clipped to `max_total_size`. Defaults to false. | |||||
| /// </param> | |||||
| /// <param name="clip_boxes"> | |||||
| /// If true, the coordinates of output nmsed boxes will be clipped | |||||
| /// to[0, 1]. If false, output the box coordinates as it is. Defaults to true. | |||||
| /// </param> | |||||
| /// <returns> | |||||
| /// 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor containing the non-max suppressed boxes. | |||||
| /// 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing the scores for the boxes. | |||||
| /// 'nmsed_classes': A [batch_size, max_detections] float32 tensor containing the class for boxes. | |||||
| /// 'valid_detections': A [batch_size] int32 tensor indicating the number of | |||||
| /// valid detections per batch item. Only the top valid_detections[i] entries | |||||
| /// in nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the | |||||
| /// entries are zero paddings. | |||||
| /// </returns> | |||||
| public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression( | |||||
| Tensor boxes, | |||||
| Tensor scores, | |||||
| int max_output_size_per_class, | |||||
| int max_total_size, | |||||
| float iou_threshold, | |||||
| float score_threshold, | |||||
| bool pad_per_class = false, | |||||
| bool clip_boxes = true) | |||||
| { | |||||
| var iou_threshold_t = ops.convert_to_tensor(iou_threshold, TF_DataType.TF_FLOAT, name: "iou_threshold"); | |||||
| var score_threshold_t = ops.convert_to_tensor(score_threshold, TF_DataType.TF_FLOAT, name: "score_threshold"); | |||||
| var max_total_size_t = ops.convert_to_tensor(max_total_size); | |||||
| var max_output_size_per_class_t = ops.convert_to_tensor(max_output_size_per_class); | |||||
| return gen_image_ops.combined_non_max_suppression(boxes, scores, max_output_size_per_class_t, max_total_size_t, | |||||
| iou_threshold_t, score_threshold_t, pad_per_class, clip_boxes); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change. | /// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change. | ||||
| @@ -187,7 +282,19 @@ namespace Tensorflow | |||||
| /// <param name="name">A name for the operation (optional).</param> | /// <param name="name">A name for the operation (optional).</param> | ||||
| /// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns> | /// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns> | ||||
| public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) => | public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) => | ||||
| image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); | |||||
| gen_image_ops.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); | |||||
| public Tensor decode_jpeg(Tensor contents, | |||||
| int channels = 0, | |||||
| int ratio = 1, | |||||
| bool fancy_upscaling = true, | |||||
| bool try_recover_truncated = false, | |||||
| int acceptable_fraction = 1, | |||||
| string dct_method = "", | |||||
| string name = null) | |||||
| => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | |||||
| fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | |||||
| acceptable_fraction: acceptable_fraction, dct_method: dct_method); | |||||
| public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true, | public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true, | ||||
| bool uniform_noise = true, string name = null) | bool uniform_noise = true, string name = null) | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Tensorflow.NumPy; | |||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -42,10 +43,20 @@ namespace Tensorflow | |||||
| public Tensor multiply(Tensor x, Tensor y, string name = null) | public Tensor multiply(Tensor x, Tensor y, string name = null) | ||||
| => math_ops.multiply(x, y, name: name); | => math_ops.multiply(x, y, name: name); | ||||
| public Tensor divide_no_nan(Tensor a, Tensor b, string name = null) | public Tensor divide_no_nan(Tensor a, Tensor b, string name = null) | ||||
| => math_ops.div_no_nan(a, b); | => math_ops.div_no_nan(a, b); | ||||
| /// <summary> | |||||
| /// Computes the Euclidean norm of elements across dimensions of a tensor. | |||||
| /// </summary> | |||||
| /// <param name="input_tensor">The tensor to reduce. Should have numeric type.</param> | |||||
| /// <param name="axis">The dimensions to reduce. If `None` (the default), reduces all dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`</param> | |||||
| /// <param name="keepdims">If true, retains reduced dimensions with length 1.</param> | |||||
| /// <param name="name">A name for the operation (optional).</param> | |||||
| /// <returns>The reduced tensor, of the same dtype as the input_tensor.</returns> | |||||
| public Tensor reduce_euclidean_norm(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) | |||||
| => math_ops.reduce_euclidean_norm(input_tensor, axis: axis, keepdims: keepdims, name); | |||||
| public Tensor square(Tensor x, string name = null) | public Tensor square(Tensor x, string name = null) | ||||
| => math_ops.square(x, name: name); | => math_ops.square(x, name: name); | ||||
| @@ -354,7 +365,7 @@ namespace Tensorflow | |||||
| => a / b; | => a / b; | ||||
| public Tensor sqrt(Tensor a, string name = null) | public Tensor sqrt(Tensor a, string name = null) | ||||
| => gen_math_ops.sqrt(a, name); | |||||
| => math_ops.sqrt(a, name); | |||||
| public Tensor sign(Tensor a, string name = null) | public Tensor sign(Tensor a, string name = null) | ||||
| => gen_math_ops.sign(a, name); | => gen_math_ops.sign(a, name); | ||||
| @@ -452,7 +463,18 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| => gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | => gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | ||||
| /// <summary> | |||||
| /// return scalar product | |||||
| /// </summary> | |||||
| /// <typeparam name="Tx"></typeparam> | |||||
| /// <typeparam name="Ty"></typeparam> | |||||
| /// <param name="x"></param> | |||||
| /// <param name="y"></param> | |||||
| /// <param name="axes"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor dot_prod<Tx, Ty>(Tx x, Ty y, NDArray axes, string name = null) | |||||
| => math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name); | |||||
| public Tensor negative(Tensor x, string name = null) | public Tensor negative(Tensor x, string name = null) | ||||
| => gen_math_ops.neg(x, name); | => gen_math_ops.neg(x, name); | ||||
| @@ -600,5 +622,7 @@ namespace Tensorflow | |||||
| => gen_math_ops.squared_difference(x: x, y: y, name: name); | => gen_math_ops.squared_difference(x: x, y: y, name: name); | ||||
| public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null, | public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null, | ||||
| string name = null) => gen_ops.complex(real, imag, dtype, name); | string name = null) => gen_ops.complex(real, imag, dtype, name); | ||||
| public Tensor exp(Tensor x, | |||||
| string name = null) => gen_math_ops.exp(x, name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Xml.Linq; | |||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using Tensorflow.Operations.Activation; | using Tensorflow.Operations.Activation; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -126,6 +127,26 @@ namespace Tensorflow | |||||
| name: name, | name: name, | ||||
| exponential_avg_factor: exponential_avg_factor); | exponential_avg_factor: exponential_avg_factor); | ||||
| /// <summary> | |||||
| /// Normalizes a tensor by `mean` and `variance`, and applies (optionally) a`scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\). | |||||
| /// </summary> | |||||
| /// <param name="x">A floating point tensor.</param> | |||||
| /// <param name="mean">A mean `Tensor`.</param> | |||||
| /// <param name="variance">A variance `Tensor`.</param> | |||||
| /// <param name="offset"> An offset `Tensor`, often denoted \\(\beta\\) in equations, or NULL. If present, will be added to the normalized tensor.</param> | |||||
| /// <param name="scale"> A scale `Tensor`, often denoted \\(\gamma\\) in equations, or NULL. If present, the scale is applied to the normalized tensor.</param> | |||||
| /// <param name="variance_epsilon"> A small float number to avoid dividing by 0.</param> | |||||
| /// <param name="name">A name for this operation.</param> | |||||
| /// <returns>the normalized, scaled, offset tensor.</returns> | |||||
| public Tensor batch_normalization(Tensor x, | |||||
| Tensor mean, | |||||
| Tensor variance, | |||||
| Tensor offset, | |||||
| Tensor scale, | |||||
| float variance_epsilon, | |||||
| string name = null) => nn_impl.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name); | |||||
| public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null) | public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null) | ||||
| => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name); | => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name); | ||||
| @@ -31,6 +31,6 @@ namespace Tensorflow | |||||
| public Tensor reshape(Tensor tensor, | public Tensor reshape(Tensor tensor, | ||||
| object[] shape, | object[] shape, | ||||
| string name = null) | string name = null) | ||||
| => gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name); | |||||
| => array_ops.reshape(tensor, shape, name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -68,20 +68,27 @@ namespace Tensorflow | |||||
| /// <param name="name">A name for the operation (optional)</param> | /// <param name="name">A name for the operation (optional)</param> | ||||
| /// <returns>if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; | /// <returns>if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; | ||||
| /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.</returns> | /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.</returns> | ||||
| public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) | |||||
| public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) | |||||
| => array_ops.split( | => array_ops.split( | ||||
| value: value, | value: value, | ||||
| num_split: num_split, | |||||
| num_or_size_splits: num_split, | |||||
| axis: axis, | axis: axis, | ||||
| name: name); | name: name); | ||||
| public Tensor[] split(Tensor value, int num_split, int axis, string name = null) | |||||
| public Tensor[] split(Tensor value, int[] num_split, Axis axis, string name = null) | |||||
| => array_ops.split( | => array_ops.split( | ||||
| value: value, | value: value, | ||||
| num_split: num_split, | |||||
| num_or_size_splits: num_split, | |||||
| axis: axis, | axis: axis, | ||||
| name: name); | name: name); | ||||
| //public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null) | |||||
| // => array_ops.split( | |||||
| // value: value, | |||||
| // num_or_size_splits: num_split, | |||||
| // axis: axis, | |||||
| // name: name); | |||||
| public Tensor ensure_shape(Tensor x, Shape shape, string name = null) | public Tensor ensure_shape(Tensor x, Shape shape, string name = null) | ||||
| { | { | ||||
| return gen_ops.ensure_shape(x, shape, name); | return gen_ops.ensure_shape(x, shape, name); | ||||
| @@ -23,7 +23,7 @@ namespace Tensorflow | |||||
| => gen_array_ops.tile(input, multiples, name); | => gen_array_ops.tile(input, multiples, name); | ||||
| public Tensor tile(Tensor input, object[] multiples, string name = null) | public Tensor tile(Tensor input, object[] multiples, string name = null) | ||||
| => gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name); | |||||
| => array_ops.tile(input, constant_op.constant(shape_utils.from_object_array(multiples).dims), name); | |||||
| public Tensor tile(Tensor input, Shape multiples, string name = null) | public Tensor tile(Tensor input, Shape multiples, string name = null) | ||||
| { | { | ||||
| @@ -486,7 +486,28 @@ namespace Tensorflow | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| } | } | ||||
| public static NDArray GetFlattenArray(NDArray x) | |||||
| { | |||||
| switch (x.GetDataType()) | |||||
| { | |||||
| case TF_DataType.TF_FLOAT: | |||||
| x = x.ToArray<float>(); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| x = x.ToArray<double>(); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| case TF_DataType.TF_INT32: | |||||
| x = x.ToArray<int>(); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| x = x.ToArray<long>(); | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| return x; | |||||
| } | |||||
| public static TF_DataType GetDataType(this object data) | public static TF_DataType GetDataType(this object data) | ||||
| { | { | ||||
| var type = data.GetType(); | var type = data.GetType(); | ||||
| @@ -503,7 +524,7 @@ namespace Tensorflow | |||||
| case Tensors tensors: | case Tensors tensors: | ||||
| return tensors.dtype; | return tensors.dtype; | ||||
| case IEnumerable<Tensor> tensors: | case IEnumerable<Tensor> tensors: | ||||
| return tensors.First().dtype; | |||||
| return tensors.Where(x => x is not null).First().dtype; | |||||
| case RefVariable variable: | case RefVariable variable: | ||||
| return variable.dtype; | return variable.dtype; | ||||
| case ResourceVariable variable: | case ResourceVariable variable: | ||||
| @@ -3,16 +3,16 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Extensions | |||||
| namespace Tensorflow.Common.Extensions | |||||
| { | { | ||||
| public static class JObjectExtensions | public static class JObjectExtensions | ||||
| { | { | ||||
| public static T? TryGetOrReturnNull<T>(this JObject obj, string key) | public static T? TryGetOrReturnNull<T>(this JObject obj, string key) | ||||
| { | { | ||||
| var res = obj[key]; | var res = obj[key]; | ||||
| if(res is null) | |||||
| if (res is null) | |||||
| { | { | ||||
| return default(T); | |||||
| return default; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -0,0 +1,38 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Extensions | |||||
| { | |||||
| public static class LinqExtensions | |||||
| { | |||||
| #if NETSTANDARD2_0 | |||||
| public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count) | |||||
| { | |||||
| return sequence.Skip(sequence.Count() - count); | |||||
| } | |||||
| public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count) | |||||
| { | |||||
| return sequence.Take(sequence.Count() - count); | |||||
| } | |||||
| #endif | |||||
| public static Tensors ToTensors(this Tensor[] tensors) | |||||
| { | |||||
| return new Tensors(tensors); | |||||
| } | |||||
| public static Tensors ToTensors(this IList<Tensor> tensors) | |||||
| { | |||||
| return new Tensors(tensors); | |||||
| } | |||||
| public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) | |||||
| { | |||||
| first = values.Item1; | |||||
| second = values.Item2; | |||||
| third = values.Item3; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,33 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Common.Types; | |||||
| namespace Tensorflow.Common.Extensions | |||||
| { | |||||
| public static class NestExtensions | |||||
| { | |||||
| public static Tensors ToTensors(this INestable<Tensor> tensors) | |||||
| { | |||||
| return new Tensors(tensors.AsNest()); | |||||
| } | |||||
| public static Tensors? ToTensors(this Nest<Tensor> tensors) | |||||
| { | |||||
| return Tensors.FromNest(tensors); | |||||
| } | |||||
| /// <summary> | |||||
| /// If the nested object is already a nested type, this function could reduce it. | |||||
| /// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. | |||||
| /// </summary> | |||||
| /// <typeparam name="TIn"></typeparam> | |||||
| /// <typeparam name="TOut"></typeparam> | |||||
| /// <param name="input"></param> | |||||
| /// <returns></returns> | |||||
| public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut> | |||||
| { | |||||
| return Nest<TOut>.ReduceFrom(input); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,20 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| /// <summary> | |||||
| /// This is a temp solution, which should be removed after refactoring `Tensors` | |||||
| /// </summary> | |||||
| [Obsolete] | |||||
| public class FakeTensorByTensorArray: Tensor | |||||
| { | |||||
| public TensorArray TensorArray { get; set; } | |||||
| public FakeTensorByTensorArray(TensorArray array) | |||||
| { | |||||
| TensorArray = array; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,69 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| public class GeneralizedTensorShape: Nest<Shape> | |||||
| { | |||||
| public GeneralizedTensorShape(Shape value, string? name = null) | |||||
| { | |||||
| NodeValue = value; | |||||
| NestType = NestType.Node; | |||||
| } | |||||
| public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null) | |||||
| { | |||||
| ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList(); | |||||
| Name = name; | |||||
| NestType = NestType.List; | |||||
| } | |||||
| public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null) | |||||
| { | |||||
| DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>); | |||||
| Name = name; | |||||
| NestType = NestType.Dictionary; | |||||
| } | |||||
| public GeneralizedTensorShape(Nest<Shape> other) | |||||
| { | |||||
| NestType = other.NestType; | |||||
| NodeValue = other.NodeValue; | |||||
| DictValue = other.DictValue; | |||||
| ListValue = other.ListValue; | |||||
| Name = other.Name; | |||||
| } | |||||
| public Shape ToSingleShape() | |||||
| { | |||||
| var shapes = Flatten().ToList(); | |||||
| if (shapes.Count != 1) | |||||
| { | |||||
| throw new ValueError("The generalized shape contains more than 1 dim."); | |||||
| } | |||||
| return shapes[0]; | |||||
| } | |||||
| public long ToNumber() | |||||
| { | |||||
| var shapes = Flatten().ToList(); | |||||
| if (shapes.Count != 1 || shapes[0].ndim != 1) | |||||
| { | |||||
| throw new ValueError("The generalized shape contains more than 1 dim."); | |||||
| } | |||||
| return shapes[0].dims[0]; | |||||
| } | |||||
| public INestStructure<TensorShapeConfig> ToTensorShapeConfigs() | |||||
| { | |||||
| return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() }); | |||||
| } | |||||
| public static implicit operator GeneralizedTensorShape(Shape shape) | |||||
| { | |||||
| return new GeneralizedTensorShape(shape); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,40 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| /// <summary> | |||||
| /// This interface indicates that a class may have a nested structure and provide | |||||
| /// methods to manipulate with the structure. | |||||
| /// </summary> | |||||
| public interface INestStructure<T>: INestable<T> | |||||
| { | |||||
| NestType NestType { get; } | |||||
| /// <summary> | |||||
| /// The item count of depth 1 of the nested structure. | |||||
| /// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3. | |||||
| /// </summary> | |||||
| int ShallowNestedCount { get; } | |||||
| /// <summary> | |||||
| /// The total item count of depth 1 of the nested structure. | |||||
| /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. | |||||
| /// </summary> | |||||
| int TotalNestedCount { get; } | |||||
| /// <summary> | |||||
| /// Flatten the Nestable object. Node that if the object contains only one value, | |||||
| /// it will be flattened to an enumerable with one element. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| IEnumerable<T> Flatten(); | |||||
| /// <summary> | |||||
| /// Construct a new object with the same nested structure. | |||||
| /// </summary> | |||||
| /// <typeparam name="TOut"></typeparam> | |||||
| /// <param name="func"></param> | |||||
| /// <returns></returns> | |||||
| INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,11 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| public interface INestable<T> | |||||
| { | |||||
| Nest<T> AsNest(); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,21 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| /// <summary> | |||||
| /// This interface is used when some corresponding python methods have optional args. | |||||
| /// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while | |||||
| /// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs` | |||||
| /// as the parameter of the method. | |||||
| /// </summary> | |||||
| public interface IOptionalArgs | |||||
| { | |||||
| /// <summary> | |||||
| /// The identifier of the class. It is not an argument but only something to | |||||
| /// separate different OptionalArgs. | |||||
| /// </summary> | |||||
| string Identifier { get; } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,62 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| public static class Nest | |||||
| { | |||||
| /// <summary> | |||||
| /// Pack the flat items to a nested sequence by the template. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="template"></param> | |||||
| /// <param name="flatItems"></param> | |||||
| /// <returns></returns> | |||||
| public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems) | |||||
| { | |||||
| return template.AsNest().PackSequence(flatItems); | |||||
| } | |||||
| /// <summary> | |||||
| /// Pack the flat items to a nested sequence by the template. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="template"></param> | |||||
| /// <param name="flatItems"></param> | |||||
| /// <returns></returns> | |||||
| public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems) | |||||
| { | |||||
| return template.AsNest().PackSequence(flatItems.ToArray()); | |||||
| } | |||||
| /// <summary> | |||||
| /// Flatten the nested object. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="nestedObject"></param> | |||||
| /// <returns></returns> | |||||
| public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject) | |||||
| { | |||||
| return nestedObject.AsNest().Flatten(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Map the structure with specified function. | |||||
| /// </summary> | |||||
| /// <typeparam name="TIn"></typeparam> | |||||
| /// <typeparam name="TOut"></typeparam> | |||||
| /// <param name="func"></param> | |||||
| /// <param name="nestedObject"></param> | |||||
| /// <returns></returns> | |||||
| public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject) | |||||
| { | |||||
| return nestedObject.AsNest().MapStructure(func); | |||||
| } | |||||
| public static bool IsNested<T>(INestable<T> obj) | |||||
| { | |||||
| return obj.AsNest().IsNested(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,485 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Common.Extensions; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| public enum NestType | |||||
| { | |||||
| Empty, | |||||
| Node, | |||||
| List, | |||||
| Dictionary | |||||
| } | |||||
| /// <summary> | |||||
| /// A nested structure which may inclulde value, list and dictionary. | |||||
| /// Note that dictionary does not ensure the data order. When using it as IEnumerable, | |||||
| /// its order is depth-first. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| public class Nest<T> : INestStructure<T>, IEnumerable<T> | |||||
| { | |||||
| private static readonly Nest<T> _empty = new Nest<T>() | |||||
| { | |||||
| NestType = NestType.Empty, | |||||
| }; | |||||
| public static Nest<T> Empty => _empty; | |||||
| public NestType NestType { get; protected set; } | |||||
| public string? Name { get; set; } | |||||
| public T? NodeValue { get; protected set; } | |||||
| public List<INestStructure<T>>? ListValue { get; protected set; } | |||||
| public Dictionary<string, INestStructure<T>>? DictValue { get; protected set; } | |||||
| public int ShallowNestedCount | |||||
| { | |||||
| get | |||||
| { | |||||
| if (NestType == NestType.Empty) | |||||
| { | |||||
| return 0; | |||||
| } | |||||
| else if (NestType == NestType.Node) | |||||
| { | |||||
| return 1; | |||||
| } | |||||
| else if (NestType == NestType.List) | |||||
| { | |||||
| return ListValue!.Count; | |||||
| } | |||||
| else // dict | |||||
| { | |||||
| return DictValue!.Count; | |||||
| } | |||||
| } | |||||
| } | |||||
| public int TotalNestedCount | |||||
| { | |||||
| get | |||||
| { | |||||
| return Flatten().Count(); | |||||
| } | |||||
| } | |||||
| protected Nest() { } | |||||
| public Nest(T value, string? name = null) | |||||
| { | |||||
| NodeValue = value; | |||||
| Name = name; | |||||
| NestType = NestType.Node; | |||||
| } | |||||
| public Nest(IEnumerable<INestStructure<T>> values, string? name = null) | |||||
| { | |||||
| ListValue = values.ToList(); | |||||
| Name = name; | |||||
| NestType = NestType.List; | |||||
| } | |||||
| public Nest(Dictionary<string, INestStructure<T>> value, string? name = null) | |||||
| { | |||||
| DictValue = value; | |||||
| Name = name; | |||||
| NestType = NestType.Dictionary; | |||||
| } | |||||
| public Nest(Nest<T> other) | |||||
| { | |||||
| NestType = other.NestType; | |||||
| NodeValue = other.NodeValue; | |||||
| DictValue = other.DictValue; | |||||
| ListValue = other.ListValue; | |||||
| Name = other.Name; | |||||
| } | |||||
| public virtual IEnumerable<T> Flatten() | |||||
| { | |||||
| return FlattenInternal(this); | |||||
| } | |||||
| public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
| { | |||||
| return MapStructureInternal(func); | |||||
| } | |||||
| /// <summary> | |||||
| /// Pack the flat items to a nested sequence by the template. | |||||
| /// </summary> | |||||
| /// <param name="flatItems"></param> | |||||
| /// <returns></returns> | |||||
| public virtual Nest<TOut> PackSequence<TOut>(TOut[] flatItems) | |||||
| { | |||||
| if(flatItems.Length == 0) | |||||
| { | |||||
| return Nest<TOut>.Empty; | |||||
| } | |||||
| int index = 0; | |||||
| return PackSequenceInternal(this, flatItems, ref index); | |||||
| } | |||||
| private static Nest<TOut> PackSequenceInternal<TOut>(Nest<T> template, TOut[] flatItems, ref int index) | |||||
| { | |||||
| if(template.NestType == NestType.Node) | |||||
| { | |||||
| if(index >= flatItems.Length) | |||||
| { | |||||
| throw new InvalidArgumentError("The template and flat items are not matched."); | |||||
| } | |||||
| return new Nest<TOut>(flatItems[index++]); | |||||
| } | |||||
| else if(template.NestType == NestType.List) | |||||
| { | |||||
| List<Nest<TOut>> nestedObjects = new List<Nest<TOut>>(); | |||||
| for (int i = 0; i < template.ListValue!.Count; i++) | |||||
| { | |||||
| nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index)); | |||||
| } | |||||
| return new Nest<TOut>(nestedObjects); | |||||
| } | |||||
| else if(template.NestType == NestType.Node) | |||||
| { | |||||
| Dictionary<string, INestStructure<TOut>> dict = new Dictionary<string, INestStructure<TOut>>(); | |||||
| foreach(var (key, value) in template.DictValue!) | |||||
| { | |||||
| dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index); | |||||
| } | |||||
| return new Nest<TOut>(dict); | |||||
| } | |||||
| // Consider Empty as invalid type. | |||||
| throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); | |||||
| } | |||||
| public virtual Nest<T> AsNest() | |||||
| { | |||||
| return this; | |||||
| } | |||||
| public virtual Nest<T> MergeWith(Nest<T>? other) | |||||
| { | |||||
| if(other is null || other == Nest<T>.Empty) | |||||
| { | |||||
| return this; | |||||
| } | |||||
| if(this == Nest<T>.Empty) | |||||
| { | |||||
| return other; | |||||
| } | |||||
| if(NestType == NestType.Node && other.NestType == NestType.Node) | |||||
| { | |||||
| return new Nest<T>(new Nest<T>[] { this, other }); | |||||
| } | |||||
| else if(NestType == NestType.List && other.NestType == NestType.List) | |||||
| { | |||||
| return new Nest<T>(this.ListValue!.Concat(other.ListValue!)); | |||||
| } | |||||
| else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) | |||||
| { | |||||
| return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); | |||||
| } | |||||
| else | |||||
| { | |||||
| return new Nest<T>(new Nest<T>[] { this, other }); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not | |||||
| /// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public bool IsNested() | |||||
| { | |||||
| if(NestType is NestType.Empty or NestType.Node) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| else if(NestType is NestType.List) | |||||
| { | |||||
| return ListValue!.Count > 0; | |||||
| } | |||||
| else | |||||
| { | |||||
| return DictValue!.Count > 0; | |||||
| } | |||||
| } | |||||
| [Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] | |||||
| public T this[int index] | |||||
| { | |||||
| get | |||||
| { | |||||
| bool success = FindInternal(this, index, out var result); | |||||
| if (success) | |||||
| { | |||||
| return result; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new IndexOutOfRangeException(); | |||||
| } | |||||
| } | |||||
| set | |||||
| { | |||||
| bool success = SetInternal(this, index, value); | |||||
| if (!success) | |||||
| { | |||||
| throw new IndexOutOfRangeException(); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it | |||||
| /// to `Nest[T]`. | |||||
| /// </summary> | |||||
| /// <typeparam name="TOut"></typeparam> | |||||
| /// <param name="input"></param> | |||||
| /// <returns></returns> | |||||
| public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T> | |||||
| { | |||||
| var nested = input.AsNest(); | |||||
| return ReduceInternal(nested).AsNest(); | |||||
| } | |||||
| private static INestStructure<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||||
| { | |||||
| if(node.NestType == NestType.Empty) | |||||
| { | |||||
| return Nest<T>.Empty; | |||||
| } | |||||
| else if(node.NestType == NestType.Node) | |||||
| { | |||||
| return node.NodeValue!.AsNest(); | |||||
| } | |||||
| else if(node.NestType == NestType.List) | |||||
| { | |||||
| return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x.AsNest()))); | |||||
| } | |||||
| else // Dictionary type | |||||
| { | |||||
| return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest()))); | |||||
| } | |||||
| } | |||||
| private static bool FindInternal(Nest<T> node, int index, out T? result) | |||||
| { | |||||
| if (node.NestType == NestType.Node) | |||||
| { | |||||
| if(index == 0) | |||||
| { | |||||
| result = node.NodeValue!; | |||||
| return true; | |||||
| } | |||||
| result = default(T); | |||||
| return false; | |||||
| } | |||||
| else if (node.NestType == NestType.List) | |||||
| { | |||||
| foreach (var item in node.ListValue!) | |||||
| { | |||||
| if(index == 0) | |||||
| { | |||||
| return FindInternal(item.AsNest(), index, out result); | |||||
| } | |||||
| index--; | |||||
| } | |||||
| result = default(T); | |||||
| return false; | |||||
| } | |||||
| else if(node.NestType == NestType.Dictionary) | |||||
| { | |||||
| foreach (var item in node.DictValue!.Values) | |||||
| { | |||||
| if (index == 0) | |||||
| { | |||||
| return FindInternal(item.AsNest(), index, out result); | |||||
| } | |||||
| index--; | |||||
| } | |||||
| result = default(T); | |||||
| return false; | |||||
| } | |||||
| else | |||||
| { | |||||
| result = default(T); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| private static bool SetInternal(Nest<T> node, int index, T newValue) | |||||
| { | |||||
| if (node.NestType == NestType.Node) | |||||
| { | |||||
| if (index == 0) | |||||
| { | |||||
| node.NodeValue = newValue; | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| else if (node.NestType == NestType.List) | |||||
| { | |||||
| foreach (var item in node.ListValue!) | |||||
| { | |||||
| if (index == 0) | |||||
| { | |||||
| return SetInternal(item.AsNest(), index, newValue); | |||||
| } | |||||
| index--; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| else if (node.NestType == NestType.Dictionary) | |||||
| { | |||||
| foreach (var item in node.DictValue!.Values) | |||||
| { | |||||
| if (index == 0) | |||||
| { | |||||
| return SetInternal(item.AsNest(), index, newValue); | |||||
| } | |||||
| index--; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| else | |||||
| { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| private static IEnumerable<T> FlattenInternal(Nest<T> node) | |||||
| { | |||||
| if (node.NestType == NestType.Node) | |||||
| { | |||||
| yield return node.NodeValue!; | |||||
| } | |||||
| else if (node.NestType == NestType.List) | |||||
| { | |||||
| foreach (var item in node.ListValue!) | |||||
| { | |||||
| foreach(var val in FlattenInternal(item.AsNest())) | |||||
| { | |||||
| yield return val; | |||||
| } | |||||
| } | |||||
| } | |||||
| else if (node.NestType == NestType.Dictionary) | |||||
| { | |||||
| foreach (var item in node.DictValue!.Values) | |||||
| { | |||||
| foreach (var val in FlattenInternal(item.AsNest())) | |||||
| { | |||||
| yield return val; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func) | |||||
| { | |||||
| if (NestType == NestType.Node) | |||||
| { | |||||
| return new Nest<TOut>(func(NodeValue!)); | |||||
| } | |||||
| else if (NestType == NestType.List) | |||||
| { | |||||
| List<Nest<TOut>> outs = new List<Nest<TOut>>(); | |||||
| foreach (var item in ListValue!) | |||||
| { | |||||
| outs.Add(item.AsNest().MapStructureInternal(func)); | |||||
| } | |||||
| return new Nest<TOut>(outs); | |||||
| } | |||||
| else if (NestType == NestType.Dictionary) | |||||
| { | |||||
| Dictionary<string, INestStructure<TOut>> outs = new Dictionary<string, INestStructure<TOut>>(); | |||||
| foreach (var (key, value) in DictValue!) | |||||
| { | |||||
| outs.Add(key, value.AsNest().MapStructureInternal(func)); | |||||
| } | |||||
| return new Nest<TOut>(outs); | |||||
| } | |||||
| else | |||||
| { | |||||
| return Nest<TOut>.Empty; | |||||
| } | |||||
| } | |||||
| public IEnumerator<T> GetEnumerator() | |||||
| { | |||||
| return Flatten().GetEnumerator(); | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| return GetEnumerator(); | |||||
| } | |||||
| public override string ToString() | |||||
| { | |||||
| StringBuilder sb = new StringBuilder(); | |||||
| sb.Append("("); | |||||
| WriteString(this, sb); | |||||
| sb.Append(")"); | |||||
| return sb.ToString(); | |||||
| } | |||||
| private static void WriteString(Nest<T> node, StringBuilder sb) | |||||
| { | |||||
| if (!string.IsNullOrEmpty(node.Name)) | |||||
| { | |||||
| sb.Append($"{node.Name}: "); | |||||
| } | |||||
| if (node.NestType == NestType.Node) | |||||
| { | |||||
| sb.Append(node.NodeValue!.ToString()); | |||||
| } | |||||
| else if (node.NestType == NestType.List) | |||||
| { | |||||
| sb.Append("["); | |||||
| for(int i = 0; i < node.ListValue!.Count; i++) | |||||
| { | |||||
| WriteString(node.ListValue![i].AsNest(), sb); | |||||
| if(i != node.ListValue!.Count - 1) | |||||
| { | |||||
| sb.Append(", "); | |||||
| } | |||||
| } | |||||
| sb.Append("]"); | |||||
| } | |||||
| else if (node.NestType == NestType.Dictionary) | |||||
| { | |||||
| sb.Append("{"); | |||||
| int count = node.DictValue!.Count; | |||||
| int i = 0; | |||||
| foreach (var (key, value) in node.DictValue!) | |||||
| { | |||||
| sb.Append($"{key}: "); | |||||
| WriteString(value.AsNest(), sb); | |||||
| if (i != count - 1) | |||||
| { | |||||
| sb.Append(", "); | |||||
| } | |||||
| i++; | |||||
| } | |||||
| sb.Append("}"); | |||||
| } | |||||
| else | |||||
| { | |||||
| sb.Append("<empty>"); | |||||
| } | |||||
| } | |||||
| public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>) inputs) | |||||
| { | |||||
| return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2 }); | |||||
| } | |||||
| public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>, INestStructure<T>) inputs) | |||||
| { | |||||
| return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2, inputs.Item3 }); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,103 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull | |||||
| { | |||||
| public NestType NestType => NestType.Dictionary; | |||||
| public IDictionary<TKey, TValue> Value { get; set; } | |||||
| public int ShallowNestedCount => Values.Count; | |||||
| public int TotalNestedCount => Values.Count; | |||||
| public NestDictionary(IDictionary<TKey, TValue> dict) | |||||
| { | |||||
| Value = dict; | |||||
| } | |||||
| public IEnumerable<TValue> Flatten() | |||||
| { | |||||
| return Value.Select(x => x.Value); | |||||
| } | |||||
| public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func) | |||||
| { | |||||
| return new NestList<TOut>(Value.Select(x => func(x.Value))); | |||||
| } | |||||
| public Nest<TValue> AsNest() | |||||
| { | |||||
| return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x))); | |||||
| } | |||||
| // Required IDictionary<TKey, TValue> members | |||||
| public int Count => Value.Count; | |||||
| public bool IsReadOnly => Value.IsReadOnly; | |||||
| public ICollection<TKey> Keys => Value.Keys; | |||||
| public ICollection<TValue> Values => Value.Values; | |||||
| public void Add(TKey key, TValue value) | |||||
| { | |||||
| Value.Add(key, value); | |||||
| } | |||||
| public void Add(KeyValuePair<TKey, TValue> item) | |||||
| { | |||||
| Value.Add(item); | |||||
| } | |||||
| public void Clear() | |||||
| { | |||||
| Value.Clear(); | |||||
| } | |||||
| public bool Contains(KeyValuePair<TKey, TValue> item) | |||||
| { | |||||
| return Value.Contains(item); | |||||
| } | |||||
| public bool ContainsKey(TKey key) | |||||
| { | |||||
| return Value.ContainsKey(key); | |||||
| } | |||||
| public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||||
| { | |||||
| Value.CopyTo(array, arrayIndex); | |||||
| } | |||||
| public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||||
| { | |||||
| return Value.GetEnumerator(); | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| return GetEnumerator(); | |||||
| } | |||||
| public bool Remove(TKey key) | |||||
| { | |||||
| return Value.Remove(key); | |||||
| } | |||||
| public bool Remove(KeyValuePair<TKey, TValue> item) | |||||
| { | |||||
| return Value.Remove(item); | |||||
| } | |||||
| public bool TryGetValue(TKey key, out TValue value) | |||||
| { | |||||
| return Value.TryGetValue(key, out value); | |||||
| } | |||||
| // Optional IDictionary<TKey, TValue> members | |||||
| public TValue this[TKey key] | |||||
| { | |||||
| get => Value[key]; | |||||
| set => Value[key] = value; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,53 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| /// <summary> | |||||
| /// The implementation of a list that support nest structure, in which the depth is 1. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| public sealed class NestList<T> : INestStructure<T>, IEnumerable<T> | |||||
| { | |||||
| public NestType NestType => NestType.List; | |||||
| public List<T> Values { get; set; } | |||||
| public int ShallowNestedCount => Values.Count; | |||||
| public int TotalNestedCount => Values.Count; | |||||
| public NestList(params T[] values) | |||||
| { | |||||
| Values = new List<T>(values); | |||||
| } | |||||
| public NestList(IEnumerable<T> values) | |||||
| { | |||||
| Values = new List<T>(values); | |||||
| } | |||||
| public IEnumerable<T> Flatten() | |||||
| { | |||||
| return Values; | |||||
| } | |||||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
| { | |||||
| return new NestList<TOut>(Values.Select(x => func(x))); | |||||
| } | |||||
| public Nest<T> AsNest() | |||||
| { | |||||
| return new Nest<T>(Values.Select(x => new Nest<T>(x))); | |||||
| } | |||||
| // Enumerator implementation | |||||
| public IEnumerator<T> GetEnumerator() | |||||
| { | |||||
| return Values.GetEnumerator(); | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| return GetEnumerator(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,36 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| /// <summary> | |||||
| /// A nested structure with only one element. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| public class NestNode<T> : INestStructure<T> | |||||
| { | |||||
| public NestType NestType => NestType.Node; | |||||
| public T Value { get; set; } | |||||
| public int ShallowNestedCount => 1; | |||||
| public int TotalNestedCount => 1; | |||||
| public NestNode(T value) | |||||
| { | |||||
| Value = value; | |||||
| } | |||||
| public IEnumerable<T> Flatten() | |||||
| { | |||||
| yield return Value; | |||||
| } | |||||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
| { | |||||
| return new NestNode<TOut>(func(Value)); | |||||
| } | |||||
| public Nest<T> AsNest() | |||||
| { | |||||
| return new Nest<T>(Value); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -3,7 +3,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| namespace Tensorflow.Keras.Saving | |||||
| namespace Tensorflow.Common.Types | |||||
| { | { | ||||
| public class TensorShapeConfig | public class TensorShapeConfig | ||||
| { | { | ||||
| @@ -161,8 +161,8 @@ namespace Tensorflow | |||||
| break; | break; | ||||
| } | } | ||||
| yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ? | |||||
| null : new Tensors(results.Skip(FirstInputTensorCount))); | |||||
| yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ? | |||||
| null : new Tensors(results.Skip(FirstInputTensorCount).ToArray())); | |||||
| } | } | ||||
| } | } | ||||
| @@ -352,13 +352,19 @@ namespace Tensorflow.Eager | |||||
| c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); | c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); | ||||
| break; | break; | ||||
| case TF_AttrType.TF_ATTR_SHAPE: | case TF_AttrType.TF_ATTR_SHAPE: | ||||
| var dims = (value as long[]).ToArray(); | |||||
| long[] dims; | |||||
| if (value is Shape shape) dims = shape.dims.ToArray(); | |||||
| else if (value is long[] longs) dims = longs.ToArray(); | |||||
| else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray(); | |||||
| else dims = ((long[])value).ToArray(); | |||||
| c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); | c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| break; | break; | ||||
| case TF_AttrType.TF_ATTR_FUNC: | case TF_AttrType.TF_ATTR_FUNC: | ||||
| if (value is ConcreteFunction func) | if (value is ConcreteFunction func) | ||||
| c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | ||||
| else if(value is string str) | |||||
| c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length); | |||||
| else | else | ||||
| throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | ||||
| break; | break; | ||||
| @@ -65,7 +65,7 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| outgrad_vec = output_gradients.ToList(); | outgrad_vec = output_gradients.ToList(); | ||||
| } | } | ||||
| var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false); | |||||
| var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true); | |||||
| bool unconnected_gradients_zero = unconnected_gradients == "zero"; | bool unconnected_gradients_zero = unconnected_gradients == "zero"; | ||||
| @@ -137,7 +137,6 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | ||||
| } | } | ||||
| Shape tensor_shape = new(dims); | |||||
| if(status.Code != TF_Code.TF_OK) | if(status.Code != TF_Code.TF_OK) | ||||
| { | { | ||||
| @@ -145,6 +144,7 @@ namespace Tensorflow.Eager | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| Shape tensor_shape = new(dims); | |||||
| return new TapeTensor(id, dtype, tensor_shape); | return new TapeTensor(id, dtype, tensor_shape); | ||||
| } | } | ||||
| } | } | ||||
| @@ -173,8 +173,12 @@ namespace Tensorflow.Eager | |||||
| return dtype == dtypes.variant || dtype == dtypes.resource; | return dtype == dtypes.variant || dtype == dtypes.resource; | ||||
| } | } | ||||
| bool ListContainNone(long[] list) | |||||
| bool ListContainNone(long[]? list) | |||||
| { | { | ||||
| if(list is null) | |||||
| { | |||||
| return true; | |||||
| } | |||||
| int len = list.Length; | int len = list.Length; | ||||
| if(len == 0) | if(len == 0) | ||||
| { | { | ||||
| @@ -10,6 +10,11 @@ namespace Tensorflow.Eager | |||||
| var str = NDArrayRender.ToString(nd); | var str = NDArrayRender.ToString(nd); | ||||
| return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | ||||
| } | } | ||||
| public string ToString(int maxLength) | |||||
| { | |||||
| var nd = new NDArray(this); | |||||
| var str = NDArrayRender.ToString(nd, maxLength); | |||||
| return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,25 @@ | |||||
| using Tensorflow; | |||||
| internal static class GraphOnlyOps | |||||
| { | |||||
| /// <summary> | |||||
| /// Graph-only version of tf.compat.v1.placeholder(), for internal use only. | |||||
| /// </summary> | |||||
| /// <param name="dtyype"></param> | |||||
| /// <param name="shape"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null) | |||||
| { | |||||
| var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() }; | |||||
| var shape_value = new AttrValue() { Shape = shape.as_proto() }; | |||||
| var g = ops.get_default_graph(); | |||||
| Dictionary<string, AttrValue> attrs = new(); | |||||
| attrs["dtype"] = dtype_value; | |||||
| attrs["shape"] = shape_value; | |||||
| var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype }, | |||||
| new TF_DataType[0], attrs: attrs, name: name); | |||||
| var result = op.outputs[0]; | |||||
| return result; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,19 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Exceptions | |||||
| { | |||||
| public class NotOkStatusException : TensorflowException | |||||
| { | |||||
| public NotOkStatusException() : base() | |||||
| { | |||||
| } | |||||
| public NotOkStatusException(string message) : base(message) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -49,12 +49,25 @@ namespace Tensorflow.Framework | |||||
| public static implicit operator Tensor(IndexedSlices indexedSlices) | public static implicit operator Tensor(IndexedSlices indexedSlices) | ||||
| { | { | ||||
| return indexedSlices.values; | |||||
| return _indexed_slices_to_tensor(indexedSlices); | |||||
| } | } | ||||
| public static implicit operator IndexedSlices(Tensor tensor) | public static implicit operator IndexedSlices(Tensor tensor) | ||||
| { | { | ||||
| return tensor.Tag as IndexedSlices; | return tensor.Tag as IndexedSlices; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Converts an IndexedSlices object `value` to a Tensor. | |||||
| /// </summary> | |||||
| /// <param name="indexedSlices"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="as_ref"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false) | |||||
| { | |||||
| return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow.Framework.Models | namespace Tensorflow.Framework.Models | ||||
| { | { | ||||
| @@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models | |||||
| shapes.Insert(0, dim); | shapes.Insert(0, dim); | ||||
| return new TensorSpec(shapes.ToArray(), _dtype); | return new TensorSpec(shapes.ToArray(), _dtype); | ||||
| } | } | ||||
| public static TensorSpec FromTensor(Tensor tensor, string? name = null) | |||||
| { | |||||
| if(tensor is EagerTensor) | |||||
| { | |||||
| return new TensorSpec(tensor.shape, tensor.dtype, name); | |||||
| } | |||||
| else | |||||
| { | |||||
| return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,89 @@ | |||||
| using Tensorflow.Graphs; | |||||
| namespace Tensorflow.Framework | |||||
| { | |||||
| internal static class auto_control_deps_utils | |||||
| { | |||||
| public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"; | |||||
| public static List<int> get_read_only_resource_input_indices_graph(FuncGraph func_graph) | |||||
| { | |||||
| List<int> result = new List<int>(); | |||||
| // A cache to store the read only resource inputs of an Op. | |||||
| // Operation -> ObjectIdentitySet of resource handles. | |||||
| Dictionary<Operation, HashSet<Tensor>> opReadOnlyResourceInputs = | |||||
| new Dictionary<Operation, HashSet<Tensor>>(); | |||||
| for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++) | |||||
| { | |||||
| Tensor t = func_graph.Inputs[inputIndex]; | |||||
| if (t.dtype != dtypes.resource) | |||||
| continue; | |||||
| bool readOnly = true; | |||||
| foreach (var op in t.consumers()) | |||||
| { | |||||
| if (opReadOnlyResourceInputs.ContainsKey(op)) | |||||
| { | |||||
| if (!opReadOnlyResourceInputs[op].Contains(t)) | |||||
| { | |||||
| readOnly = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| List<int> indices = _get_read_only_resource_input_indices_op(op); | |||||
| opReadOnlyResourceInputs[op] = new HashSet<Tensor>( | |||||
| indices.Select(i => op.inputs[i])); | |||||
| if (!opReadOnlyResourceInputs[op].Contains(t)) | |||||
| { | |||||
| readOnly = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (readOnly) | |||||
| result.Add(inputIndex); | |||||
| } | |||||
| return result; | |||||
| } | |||||
| private static List<int> _get_read_only_resource_input_indices_op(Operation op) | |||||
| { | |||||
| // ignore the RESOURCE_READ_OPS | |||||
| int[] read_only_input_indices; | |||||
| try | |||||
| { | |||||
| read_only_input_indices = op.get_attr<int[]>(READ_ONLY_RESOURCE_INPUTS_ATTR); | |||||
| } | |||||
| catch (InvalidArgumentError) | |||||
| { | |||||
| return new List<int>(); | |||||
| } | |||||
| int read_only_index = 0; | |||||
| List<int> result = new(); | |||||
| for (int i = 0; i < op.inputs.Length; i++) | |||||
| { | |||||
| if (read_only_index >= read_only_input_indices.Length) | |||||
| { | |||||
| break; | |||||
| } | |||||
| if (op.inputs[i].dtype != dtypes.resource) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index]) | |||||
| { | |||||
| result.Add(i); | |||||
| read_only_index++; | |||||
| } | |||||
| } | |||||
| return result; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -42,10 +42,10 @@ namespace Tensorflow.Framework | |||||
| func_graph.as_default(); | func_graph.as_default(); | ||||
| importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | ||||
| var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | ||||
| func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||||
| func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||||
| var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | ||||
| func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||||
| func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||||
| // TODO(Rinne): func_graph.ControlOutputs | // TODO(Rinne): func_graph.ControlOutputs | ||||
| _set_handle_data(func_graph, fdef); | _set_handle_data(func_graph, fdef); | ||||
| @@ -8,6 +8,7 @@ using Tensorflow.Gradients; | |||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using Tensorflow.Common.Extensions; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Functions | namespace Tensorflow.Functions | ||||
| @@ -40,6 +41,18 @@ namespace Tensorflow.Functions | |||||
| public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | ||||
| public IEnumerable<IVariableV1> Variables => func_graph.Variables; | public IEnumerable<IVariableV1> Variables => func_graph.Variables; | ||||
| public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | ||||
| internal NameAttrList AsNameAttrList | |||||
| { | |||||
| get | |||||
| { | |||||
| NameAttrList ret = new() { Name = this.Name }; | |||||
| foreach (var (name, value) in _attrs) | |||||
| { | |||||
| ret.Attr[name] = value; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| public ConcreteFunction(string name) | public ConcreteFunction(string name) | ||||
| { | { | ||||
| @@ -3,4 +3,7 @@ global using System.Collections.Generic; | |||||
| global using System.Text; | global using System.Text; | ||||
| global using System.Collections; | global using System.Collections; | ||||
| global using System.Data; | global using System.Data; | ||||
| global using System.Linq; | |||||
| global using System.Linq; | |||||
| global using Tensorflow.Keras.Engine; | |||||
| global using Tensorflow.Framework.Models; | |||||
| global using static Tensorflow.Binding; | |||||
| @@ -90,8 +90,7 @@ namespace Tensorflow.Gradients | |||||
| ? input_values[0].rank + dim_int | ? input_values[0].rank + dim_int | ||||
| : dim_int % input_values[0].rank; | : dim_int % input_values[0].rank; | ||||
| var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray(); | var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray(); | ||||
| var sizes_tensor = constant_op.constant(sizes); | |||||
| out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList(); | |||||
| out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList(); | |||||
| } | } | ||||
| else if (constant_op.is_constant(concat_dim)) | else if (constant_op.is_constant(concat_dim)) | ||||
| { | { | ||||
| @@ -127,7 +126,7 @@ namespace Tensorflow.Gradients | |||||
| new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | ||||
| new Tensor[] { tf.constant(1), tf.constant(-1) }); | new Tensor[] { tf.constant(1), tf.constant(-1) }); | ||||
| var squeeze_sizes = array_ops.squeeze(slice); | var squeeze_sizes = array_ops.squeeze(slice); | ||||
| out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList(); | |||||
| out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList(); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -374,5 +373,13 @@ namespace Tensorflow.Gradients | |||||
| var p = op.inputs[1]; | var p = op.inputs[1]; | ||||
| return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null }; | return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null }; | ||||
| } | } | ||||
| [RegisterGradient("ReverseV2")] | |||||
| public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads) | |||||
| { | |||||
| var grad = grads[0]; | |||||
| var axis = op.inputs[1]; | |||||
| return new Tensor[] { array_ops.reverse(grad, axis), null }; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -117,6 +117,137 @@ namespace Tensorflow.Gradients | |||||
| }; | }; | ||||
| } | } | ||||
| public static string ellipsis = "..."; | |||||
| [RegisterGradient("Einsum")] | |||||
| public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| // Gradient for Einsum. | |||||
| string equation = (string)op.get_attr("equation"); | |||||
| string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None); | |||||
| var input_subs = split_equation[0]; | |||||
| var output_subs = split_equation[1]; | |||||
| if (op.inputs.Length == 1) | |||||
| { | |||||
| var input_shape = array_ops.shape(op.inputs[0]); | |||||
| var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + ellipsis))); | |||||
| if (reduced_label_set.Count == 0) | |||||
| return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) }; | |||||
| return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) }; | |||||
| } | |||||
| string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None); | |||||
| var x_subs = split_input_subs[0]; | |||||
| var y_subs = split_input_subs[1]; | |||||
| // Add ellipsis for broadcasted dimensions if any operand does not have it. | |||||
| // This is because the equation "...ij,jk->ik" may be valid if the 0th input's | |||||
| // batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid | |||||
| // because only the output subscripts contain ellipsis. | |||||
| if (output_subs.Contains(ellipsis)) | |||||
| { | |||||
| if (!x_subs.Contains(ellipsis)) | |||||
| x_subs += ellipsis; | |||||
| if (!y_subs.Contains(ellipsis)) | |||||
| y_subs += ellipsis; | |||||
| } | |||||
| // Obtain the gradients wrt the inputs x and y, without taking into account | |||||
| // the unbroadcasting. | |||||
| var x = op.inputs[0]; | |||||
| var y = op.inputs[1]; | |||||
| if (grads.GetDataType().is_complex()) | |||||
| { | |||||
| x = math_ops.conj(x); | |||||
| y = math_ops.conj(y); | |||||
| } | |||||
| var x_shape = array_ops.shape(x); | |||||
| var y_shape = array_ops.shape(y); | |||||
| var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs); | |||||
| var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs); | |||||
| if (!output_subs.Contains(ellipsis)) | |||||
| return new Tensor[] { grad_x, grad_y }; | |||||
| var bx = _GetBcastSubshape(x_subs); | |||||
| int bx_start = bx[0], bx_end = bx[1]; | |||||
| var by = _GetBcastSubshape(y_subs); | |||||
| int by_start = by[0], by_end = by[1]; | |||||
| var x_shape_static = x.shape; | |||||
| var y_shape_static = y.shape; | |||||
| if(x_shape_static.IsFullyDefined && | |||||
| y_shape_static.IsFullyDefined && | |||||
| x_shape_static[string.Format("{0}:{1}",bx_start,bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)]) | |||||
| return new Tensor[] { grad_x, grad_y }; | |||||
| var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)], | |||||
| y_shape[string.Format("{0}:{1}", by_start, by_end)]); | |||||
| var rx = r[0]; | |||||
| var ry = r[1]; | |||||
| grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape); | |||||
| grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape); | |||||
| return new Tensor[] { grad_x, grad_y }; | |||||
| } | |||||
| protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape, | |||||
| string input_subs, string other_subs, string output_subs) | |||||
| { | |||||
| var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + other_subs + "."))); | |||||
| var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); | |||||
| var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs), new Tensors((Tensors)output_grads, other_operand)); | |||||
| if (reduced_label_set.Count == 0) | |||||
| return grad_reduced; | |||||
| return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set); | |||||
| } | |||||
| protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set) | |||||
| { | |||||
| string reduced_subs; | |||||
| Tensor reduced_dims; | |||||
| List<int> reduced_axes; | |||||
| _GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes); | |||||
| bool has_repeated_labels = ( | |||||
| new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count < | |||||
| input_subs.Length + output_subs.Length); | |||||
| var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s))); | |||||
| if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs) | |||||
| { | |||||
| var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes)); | |||||
| return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape); | |||||
| } | |||||
| else | |||||
| { | |||||
| var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0); | |||||
| var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0); | |||||
| var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels); | |||||
| return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad)); | |||||
| } | |||||
| } | |||||
| protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes) | |||||
| { | |||||
| reduced_subs = string.Join("", reduced_label_set.Select(c => c.ToString())); | |||||
| reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList(); | |||||
| reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList()); | |||||
| } | |||||
| protected static int _GetAxisFromLabel(string subscripts, char label) | |||||
| { | |||||
| var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None); | |||||
| var index = splits[0].IndexOf(label); | |||||
| if (index != -1) return index; | |||||
| if (splits.Length < 2) throw new OutOfRangeError(); | |||||
| index = splits[1].IndexOf(label); | |||||
| if (index != -1) return index; | |||||
| throw new ValueError(); | |||||
| } | |||||
| protected static int[] _GetBcastSubshape(string subscripts) | |||||
| { | |||||
| int start = subscripts.IndexOf(ellipsis); | |||||
| if (start == -1) return new int[] { 0, 0 }; | |||||
| int remaining = subscripts.Length - (start + ellipsis.Length); | |||||
| int end; | |||||
| if (remaining > 0) end = remaining; | |||||
| else throw new Exception(); | |||||
| return new int[] { start, end }; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns grad * exp(x). | /// Returns grad * exp(x). | ||||
| /// </summary> | /// </summary> | ||||
| @@ -365,6 +365,23 @@ namespace Tensorflow.Gradients | |||||
| }; | }; | ||||
| } | } | ||||
| [RegisterGradient("AvgPool")] | |||||
| public static Tensor[] _AvgPoolGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| Tensor grad = grads[0]; | |||||
| return new Tensor[] | |||||
| { | |||||
| gen_nn_ops.avg_pool_grad( | |||||
| array_ops.shape(op.inputs[0]), | |||||
| grad, | |||||
| op.get_attr_list<int>("ksize"), | |||||
| op.get_attr_list<int>("strides"), | |||||
| op.get_attr<string>("padding"), | |||||
| op.get_attr<string>("data_format")) | |||||
| }; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Return the gradients for TopK. | /// Return the gradients for TopK. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable | |||||
| public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | ||||
| public Dictionary<string, AttrValue> Attrs { get; set; } | public Dictionary<string, AttrValue> Attrs { get; set; } | ||||
| Dictionary<long, (Tensor, Tensor)> _captures | |||||
| internal Dictionary<long, (Tensor, Tensor)> _captures | |||||
| = new Dictionary<long, (Tensor, Tensor)>(); | = new Dictionary<long, (Tensor, Tensor)>(); | ||||
| public Tensor[] external_captures | public Tensor[] external_captures | ||||
| @@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable | |||||
| var flat_func_args = nest.flatten(func_args as object); | var flat_func_args = nest.flatten(func_args as object); | ||||
| var flat_func_kwargs = nest.flatten(func_kwargs as object); | var flat_func_kwargs = nest.flatten(func_kwargs as object); | ||||
| func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | ||||
| .Where(x => x is Tensor).Select(x => (Tensor)x)); | |||||
| .Where(x => x is Tensor).Select(x => (Tensor)x).ToArray()); | |||||
| //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | ||||
| //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | ||||
| @@ -544,12 +544,12 @@ public class FuncGraph : Graph, IDisposable | |||||
| Tensor placeholder; | Tensor placeholder; | ||||
| try | try | ||||
| { | { | ||||
| placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); | |||||
| placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape, name); | |||||
| } | } | ||||
| catch (ValueError) | |||||
| catch (ValueError ex) | |||||
| { | { | ||||
| // TODO(Rinne): Add warning here. | |||||
| placeholder = tf.placeholder(tensor.dtype, tensor.shape); | |||||
| tf.Logger.Warning(ex.ToString()); | |||||
| placeholder = GraphOnlyOps.graph_placeholder(tensor.dtype, tensor.shape); | |||||
| } | } | ||||
| handle_data_util.copy_handle_data(tensor, placeholder); | handle_data_util.copy_handle_data(tensor, placeholder); | ||||
| if (name is not null) | if (name is not null) | ||||
| @@ -575,12 +575,12 @@ public class FuncGraph : Graph, IDisposable | |||||
| Tensor placeholder; | Tensor placeholder; | ||||
| try | try | ||||
| { | { | ||||
| placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); | |||||
| placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape, requested_name); | |||||
| } | } | ||||
| catch (ValueError) | catch (ValueError) | ||||
| { | { | ||||
| // TODO(Rinne): Add warning here. | // TODO(Rinne): Add warning here. | ||||
| placeholder = tf.placeholder(spec.dtype, spec.shape); | |||||
| placeholder = GraphOnlyOps.graph_placeholder(spec.dtype, spec.shape); | |||||
| } | } | ||||
| if (name is not null) | if (name is not null) | ||||
| { | { | ||||
| @@ -129,7 +129,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| protected Graph outer_graph; | |||||
| internal Graph outer_graph; | |||||
| public Graph OuterGraph => outer_graph; | public Graph OuterGraph => outer_graph; | ||||
| public Dictionary<string, EagerDefinedFunction> Functions => _functions; | public Dictionary<string, EagerDefinedFunction> Functions => _functions; | ||||
| public SafeGraphHandle c_graph => _handle; | public SafeGraphHandle c_graph => _handle; | ||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class ExponentialArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class HardSigmoidArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,11 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class SELUArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class SoftplusArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class SoftsignArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class SwishArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class TanhArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class Conv2DTransposeArgs : Conv2DArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class AddArgs : MergeArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class ConcatenateArgs : MergeArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class SubtractArgs : MergeArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class GlobalAveragePooling1DArgs : Pooling1DArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class GlobalAveragePooling2DArgs : Pooling2DArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class GlobalMaxPooling1DArgs : Pooling1DArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class GlobalMaxPooling2DArgs : Pooling2DArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class MaxPooling1DArgs : Pooling1DArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -7,7 +7,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| [JsonProperty("size")] | [JsonProperty("size")] | ||||
| public Shape Size { get; set; } | public Shape Size { get; set; } | ||||
| [JsonProperty("data_format")] | [JsonProperty("data_format")] | ||||
| public string DataFormat { get; set; } | |||||
| public string DataFormat { get; set; } = "channels_last"; | |||||
| /// <summary> | /// <summary> | ||||
| /// 'nearest', 'bilinear' | /// 'nearest', 'bilinear' | ||||
| /// </summary> | /// </summary> | ||||
| @@ -0,0 +1,10 @@ | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class UpSampling1DArgs : AutoSerializeLayerArgs | |||||
| { | |||||
| [JsonProperty("size")] | |||||
| public int Size { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,20 @@ | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class BidirectionalArgs : AutoSerializeLayerArgs | |||||
| { | |||||
| [JsonProperty("layer")] | |||||
| public ILayer Layer { get; set; } | |||||
| [JsonProperty("merge_mode")] | |||||
| public string? MergeMode { get; set; } | |||||
| [JsonProperty("backward_layer")] | |||||
| public ILayer BackwardLayer { get; set; } | |||||
| public NDArray Weights { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,29 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class GRUArgs : AutoSerializeLayerArgs | |||||
| { | |||||
| public int Units { get; set; } | |||||
| public Activation Activation { get; set; } | |||||
| public Activation RecurrentActivation { get; set; } | |||||
| public bool UseBias { get; set; } = true; | |||||
| public float Dropout { get; set; } = .0f; | |||||
| public float RecurrentDropout { get; set; } = .0f; | |||||
| public IInitializer KernelInitializer { get; set; } | |||||
| public IInitializer RecurrentInitializer { get; set; } | |||||
| public IInitializer BiasInitializer { get; set; } | |||||
| public bool ReturnSequences { get;set; } | |||||
| public bool ReturnState { get;set; } | |||||
| public bool GoBackwards { get;set; } | |||||
| public bool Stateful { get;set; } | |||||
| public bool Unroll { get;set; } | |||||
| public bool TimeMajor { get;set; } | |||||
| public bool ResetAfter { get;set; } | |||||
| public int Implementation { get; set; } = 2; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,39 @@ | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class GRUCellArgs : AutoSerializeLayerArgs | |||||
| { | |||||
| [JsonProperty("units")] | |||||
| public int Units { get; set; } | |||||
| // TODO(Rinne): lack of initialized value of Activation. Merging keras | |||||
| // into tf.net could resolve it. | |||||
| [JsonProperty("activation")] | |||||
| public Activation Activation { get; set; } | |||||
| [JsonProperty("recurrent_activation")] | |||||
| public Activation RecurrentActivation { get; set; } | |||||
| [JsonProperty("use_bias")] | |||||
| public bool UseBias { get; set; } = true; | |||||
| [JsonProperty("dropout")] | |||||
| public float Dropout { get; set; } = .0f; | |||||
| [JsonProperty("recurrent_dropout")] | |||||
| public float RecurrentDropout { get; set; } = .0f; | |||||
| [JsonProperty("kernel_initializer")] | |||||
| public IInitializer KernelInitializer { get; set; } | |||||
| [JsonProperty("recurrent_initializer")] | |||||
| public IInitializer RecurrentInitializer { get; set; } | |||||
| [JsonProperty("bias_initializer")] | |||||
| public IInitializer BiasInitializer { get; set; } | |||||
| [JsonProperty("reset_after")] | |||||
| public bool ResetAfter { get;set; } | |||||
| [JsonProperty("implementation")] | |||||
| public int Implementation { get; set; } = 2; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,13 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class GRUOptionalArgs | |||||
| { | |||||
| public string Identifier => "GRU"; | |||||
| public Tensor Mask { get; set; } = null; | |||||
| } | |||||
| } | |||||
| @@ -1,11 +1,14 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class LSTMArgs : RNNArgs | public class LSTMArgs : RNNArgs | ||||
| { | { | ||||
| // TODO: maybe change the `RNNArgs` and implement this class. | // TODO: maybe change the `RNNArgs` and implement this class. | ||||
| public bool UnitForgetBias { get; set; } | public bool UnitForgetBias { get; set; } | ||||
| public float Dropout { get; set; } | |||||
| public float RecurrentDropout { get; set; } | |||||
| public int Implementation { get; set; } | public int Implementation { get; set; } | ||||
| public LSTMArgs Clone() | |||||
| { | |||||
| return (LSTMArgs)MemberwiseClone(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,7 +1,35 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
| using Newtonsoft.Json; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| // TODO: complete the implementation | // TODO: complete the implementation | ||||
| public class LSTMCellArgs : LayerArgs | |||||
| public class LSTMCellArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| [JsonProperty("units")] | |||||
| public int Units { get; set; } | |||||
| // TODO(Rinne): lack of initialized value of Activation. Merging keras | |||||
| // into tf.net could resolve it. | |||||
| [JsonProperty("activation")] | |||||
| public Activation Activation { get; set; } | |||||
| [JsonProperty("recurrent_activation")] | |||||
| public Activation RecurrentActivation { get; set; } | |||||
| [JsonProperty("use_bias")] | |||||
| public bool UseBias { get; set; } = true; | |||||
| [JsonProperty("dropout")] | |||||
| public float Dropout { get; set; } = .0f; | |||||
| [JsonProperty("recurrent_dropout")] | |||||
| public float RecurrentDropout { get; set; } = .0f; | |||||
| [JsonProperty("kernel_initializer")] | |||||
| public IInitializer KernelInitializer { get; set; } | |||||
| [JsonProperty("recurrent_initializer")] | |||||
| public IInitializer RecurrentInitializer { get; set; } | |||||
| [JsonProperty("bias_initializer")] | |||||
| public IInitializer BiasInitializer { get; set; } | |||||
| [JsonProperty("unit_forget_bias")] | |||||
| public bool UnitForgetBias { get; set; } = true; | |||||
| [JsonProperty("implementation")] | |||||
| public int Implementation { get; set; } = 2; | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,17 +1,12 @@ | |||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Keras.Layers; | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| // TODO(Rinne): add regularizers. | |||||
| public class RNNArgs : AutoSerializeLayerArgs | public class RNNArgs : AutoSerializeLayerArgs | ||||
| { | { | ||||
| public interface IRnnArgCell : ILayer | |||||
| { | |||||
| object state_size { get; } | |||||
| } | |||||
| [JsonProperty("cell")] | |||||
| // TODO: the cell should be serialized with `serialize_keras_object`. | |||||
| public IRnnArgCell Cell { get; set; } = null; | |||||
| [JsonProperty("return_sequences")] | [JsonProperty("return_sequences")] | ||||
| public bool ReturnSequences { get; set; } = false; | public bool ReturnSequences { get; set; } = false; | ||||
| [JsonProperty("return_state")] | [JsonProperty("return_state")] | ||||
| @@ -24,31 +19,31 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
| public bool Unroll { get; set; } = false; | public bool Unroll { get; set; } = false; | ||||
| [JsonProperty("time_major")] | [JsonProperty("time_major")] | ||||
| public bool TimeMajor { get; set; } = false; | public bool TimeMajor { get; set; } = false; | ||||
| // TODO: Add `num_constants` and `zero_output_for_mask`. | |||||
| public Dictionary<string, object> Kwargs { get; set; } = null; | |||||
| public int? InputDim { get; set; } | |||||
| public int? InputLength { get; set; } | |||||
| // TODO: Add `num_constants` and `zero_output_for_mask`. | |||||
| [JsonProperty("units")] | |||||
| public int Units { get; set; } | public int Units { get; set; } | ||||
| [JsonProperty("activation")] | |||||
| public Activation Activation { get; set; } | public Activation Activation { get; set; } | ||||
| [JsonProperty("recurrent_activation")] | |||||
| public Activation RecurrentActivation { get; set; } | public Activation RecurrentActivation { get; set; } | ||||
| [JsonProperty("use_bias")] | |||||
| public bool UseBias { get; set; } = true; | public bool UseBias { get; set; } = true; | ||||
| public IInitializer KernelInitializer { get; set; } | public IInitializer KernelInitializer { get; set; } | ||||
| public IInitializer RecurrentInitializer { get; set; } | public IInitializer RecurrentInitializer { get; set; } | ||||
| public IInitializer BiasInitializer { get; set; } | public IInitializer BiasInitializer { get; set; } | ||||
| [JsonProperty("dropout")] | |||||
| public float Dropout { get; set; } = .0f; | |||||
| [JsonProperty("zero_output_for_mask")] | |||||
| public bool ZeroOutputForMask { get; set; } = false; | |||||
| [JsonProperty("recurrent_dropout")] | |||||
| public float RecurrentDropout { get; set; } = .0f; | |||||
| // kernel_regularizer=None, | |||||
| // recurrent_regularizer=None, | |||||
| // bias_regularizer=None, | |||||
| // activity_regularizer=None, | |||||
| // kernel_constraint=None, | |||||
| // recurrent_constraint=None, | |||||
| // bias_constraint=None, | |||||
| // dropout=0., | |||||
| // recurrent_dropout=0., | |||||
| // return_sequences=False, | |||||
| // return_state=False, | |||||
| // go_backwards=False, | |||||
| // stateful=False, | |||||
| // unroll=False, | |||||
| // **kwargs): | |||||
| public RNNArgs Clone() | |||||
| { | |||||
| return (RNNArgs)MemberwiseClone(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Common.Types; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class RnnOptionalArgs: IOptionalArgs | |||||
| { | |||||
| public string Identifier => "Rnn"; | |||||
| public Tensor Mask { get; set; } = null; | |||||
| public Tensors Constants { get; set; } = null; | |||||
| } | |||||
| } | |||||
| @@ -1,4 +1,4 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class SimpleRNNArgs : RNNArgs | public class SimpleRNNArgs : RNNArgs | ||||
| { | { | ||||
| @@ -0,0 +1,27 @@ | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class SimpleRNNCellArgs: AutoSerializeLayerArgs | |||||
| { | |||||
| [JsonProperty("units")] | |||||
| public int Units { get; set; } | |||||
| // TODO(Rinne): lack of initialized value of Activation. Merging keras | |||||
| // into tf.net could resolve it. | |||||
| [JsonProperty("activation")] | |||||
| public Activation Activation { get; set; } | |||||
| [JsonProperty("use_bias")] | |||||
| public bool UseBias { get; set; } = true; | |||||
| [JsonProperty("dropout")] | |||||
| public float Dropout { get; set; } = .0f; | |||||
| [JsonProperty("recurrent_dropout")] | |||||
| public float RecurrentDropout { get; set; } = .0f; | |||||
| [JsonProperty("kernel_initializer")] | |||||
| public IInitializer KernelInitializer { get; set; } | |||||
| [JsonProperty("recurrent_initializer")] | |||||
| public IInitializer RecurrentInitializer { get; set; } | |||||
| [JsonProperty("bias_initializer")] | |||||
| public IInitializer BiasInitializer { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -1,10 +1,10 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Keras.Layers; | |||||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class StackedRNNCellsArgs : LayerArgs | public class StackedRNNCellsArgs : LayerArgs | ||||
| { | { | ||||
| public IList<RnnCell> Cells { get; set; } | |||||
| public Dictionary<string, object> Kwargs { get; set; } = null; | |||||
| public bool ReverseStateOrder = false; | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,24 @@ | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class WrapperArgs : AutoSerializeLayerArgs | |||||
| { | |||||
| [JsonProperty("layer")] | |||||
| public ILayer Layer { get; set; } | |||||
| public WrapperArgs(ILayer layer) | |||||
| { | |||||
| Layer = layer; | |||||
| } | |||||
| public static implicit operator WrapperArgs(BidirectionalArgs args) | |||||
| => new WrapperArgs(args.Layer); | |||||
| } | |||||
| } | |||||
| @@ -14,6 +14,9 @@ public interface ICallback | |||||
| void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | ||||
| void on_predict_end(); | void on_predict_end(); | ||||
| void on_test_begin(); | void on_test_begin(); | ||||
| void on_test_end(Dictionary<string, float> logs); | |||||
| void on_test_batch_begin(long step); | void on_test_batch_begin(long step); | ||||
| void on_test_batch_end(long end_step, Dictionary<string, float> logs); | void on_test_batch_end(long end_step, Dictionary<string, float> logs); | ||||
| } | } | ||||
| @@ -60,7 +60,7 @@ public interface IModel : ILayer | |||||
| bool skip_mismatch = false, | bool skip_mismatch = false, | ||||
| object options = null); | object options = null); | ||||
| Dictionary<string, float> evaluate(Tensor x, Tensor y, | |||||
| Dictionary<string, float> evaluate(NDArray x, NDArray y, | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| int verbose = 1, | int verbose = 1, | ||||
| int steps = -1, | int steps = -1, | ||||
| @@ -0,0 +1,75 @@ | |||||
| namespace Tensorflow.Keras.Engine; | |||||
| /// <summary> | |||||
| /// A representation of a Keras in/output during Functional API construction. | |||||
| /// </summary> | |||||
| public class KerasTensor | |||||
| { | |||||
| private Tensors _original_tensors; | |||||
| public Tensors original_tensors | |||||
| { | |||||
| get => _original_tensors; | |||||
| set => _original_tensors = value; | |||||
| } | |||||
| private Shape _inferred_value; | |||||
| public Shape inferred_value => _inferred_value; | |||||
| private string _name; | |||||
| private TensorSpec _type_spec; | |||||
| public Shape shape => _type_spec.shape; | |||||
| public TF_DataType dtype => _type_spec.dtype; | |||||
| public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string name = null) | |||||
| { | |||||
| _type_spec = type_spec; | |||||
| _inferred_value = inferred_value; | |||||
| _name = name; | |||||
| } | |||||
| public static KerasTensor from_tensor(Tensor tensor) | |||||
| { | |||||
| var type_spec = tensor.ToTensorSpec(); | |||||
| Shape? inferred_value = default; | |||||
| if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2) | |||||
| { | |||||
| inferred_value = tf.ones(tensor).shape; | |||||
| } | |||||
| var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name); | |||||
| kt.original_tensors = tensor; | |||||
| return kt; | |||||
| } | |||||
| public KerasTensor this[int idx] | |||||
| => _original_tensors.First()[idx]; | |||||
| public KerasTensor this[params Slice[] slices] | |||||
| => _original_tensors.First()[slices]; | |||||
| public override string ToString() | |||||
| => _original_tensors.Length switch | |||||
| { | |||||
| > 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]", | |||||
| 1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}", | |||||
| _ => _original_tensors.ToString(), | |||||
| }; | |||||
| private string GetInferredValueString() | |||||
| => _inferred_value == null ? "" : $" inferred_value={_inferred_value}"; | |||||
| public static implicit operator Tensors(KerasTensor kt) | |||||
| => kt._original_tensors; | |||||
| public static implicit operator Tensor(KerasTensor kt) | |||||
| { | |||||
| Tensor tensor = kt._original_tensors; | |||||
| tensor.IsFromKerasTensor = true; | |||||
| return tensor; | |||||
| } | |||||
| public static implicit operator KerasTensor(Tensor tensor) | |||||
| => from_tensor(tensor); | |||||
| public static implicit operator KerasTensor(Tensors tensors) | |||||
| => from_tensor(tensors.First()); | |||||
| } | |||||
| @@ -25,6 +25,27 @@ namespace Tensorflow.Keras | |||||
| bool amsgrad = false, | bool amsgrad = false, | ||||
| string name = "Adam"); | string name = "Adam"); | ||||
| /// <summary> | |||||
| /// Adam enables L2 weight decay on gradients. | |||||
| /// </summary> | |||||
| /// <param name="learning_rate"></param> | |||||
| /// <param name="weight_decay"></param> | |||||
| /// <param name="beta_1"></param> | |||||
| /// <param name="beta_2"></param> | |||||
| /// <param name="epsilon"></param> | |||||
| /// <param name="amsgrad"></param> | |||||
| /// <param name="decay_params"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| IOptimizer AdamW(float learning_rate = 0.001f, | |||||
| float weight_decay = 0.004f, | |||||
| float beta_1 = 0.9f, | |||||
| float beta_2 = 0.999f, | |||||
| float epsilon = 1e-7f, | |||||
| bool amsgrad = false, | |||||
| List<string> no_decay_params = null, | |||||
| string name = "AdamW"); | |||||
| /// <summary> | /// <summary> | ||||
| /// Construct a new RMSprop optimizer. | /// Construct a new RMSprop optimizer. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -42,6 +63,6 @@ namespace Tensorflow.Keras | |||||
| bool centered = false, | bool centered = false, | ||||
| string name = "RMSprop"); | string name = "RMSprop"); | ||||
| IOptimizer SGD(float learning_rate); | |||||
| IOptimizer SGD(float learning_rate = 0.01f, float momentum = 0f); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Common.Types; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| @@ -14,7 +15,7 @@ namespace Tensorflow.Keras | |||||
| List<ILayer> Layers { get; } | List<ILayer> Layers { get; } | ||||
| List<INode> InboundNodes { get; } | List<INode> InboundNodes { get; } | ||||
| List<INode> OutboundNodes { get; } | List<INode> OutboundNodes { get; } | ||||
| Tensors Apply(Tensors inputs, Tensor state = null, bool training = false); | |||||
| Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null); | |||||
| List<IVariableV1> TrainableVariables { get; } | List<IVariableV1> TrainableVariables { get; } | ||||
| List<IVariableV1> TrainableWeights { get; } | List<IVariableV1> TrainableWeights { get; } | ||||
| List<IVariableV1> NonTrainableWeights { get; } | List<IVariableV1> NonTrainableWeights { get; } | ||||
| @@ -9,6 +9,10 @@ namespace Tensorflow.Keras.Layers | |||||
| public ILayer Reshape(Shape target_shape); | public ILayer Reshape(Shape target_shape); | ||||
| public ILayer Reshape(object[] target_shape); | public ILayer Reshape(object[] target_shape); | ||||
| public ILayer UpSampling1D( | |||||
| int size | |||||
| ); | |||||
| public ILayer UpSampling2D(Shape size = null, | public ILayer UpSampling2D(Shape size = null, | ||||
| string data_format = null, | string data_format = null, | ||||
| string interpolation = "nearest"); | string interpolation = "nearest"); | ||||
| @@ -1,5 +1,7 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Layers; | |||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | ||||
| @@ -134,7 +136,7 @@ namespace Tensorflow.Keras.Layers | |||||
| public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); | public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); | ||||
| public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); | public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); | ||||
| public Tensors Input(Shape shape = null, | |||||
| public KerasTensor Input(Shape shape = null, | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| string name = null, | string name = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| @@ -159,6 +161,18 @@ namespace Tensorflow.Keras.Layers | |||||
| public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false); | public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false); | ||||
| public ILayer LeakyReLU(float alpha = 0.3f); | public ILayer LeakyReLU(float alpha = 0.3f); | ||||
| public IRnnCell LSTMCell(int uints, | |||||
| string activation = "tanh", | |||||
| string recurrent_activation = "sigmoid", | |||||
| bool use_bias = true, | |||||
| string kernel_initializer = "glorot_uniform", | |||||
| string recurrent_initializer = "orthogonal", | |||||
| string bias_initializer = "zeros", | |||||
| bool unit_forget_bias = true, | |||||
| float dropout = 0f, | |||||
| float recurrent_dropout = 0f, | |||||
| int implementation = 2); | |||||
| public ILayer LSTM(int units, | public ILayer LSTM(int units, | ||||
| Activation activation = null, | Activation activation = null, | ||||
| Activation recurrent_activation = null, | Activation recurrent_activation = null, | ||||
| @@ -192,6 +206,19 @@ namespace Tensorflow.Keras.Layers | |||||
| float offset = 0, | float offset = 0, | ||||
| Shape input_shape = null); | Shape input_shape = null); | ||||
| public IRnnCell SimpleRNNCell( | |||||
| int units, | |||||
| string activation = "tanh", | |||||
| bool use_bias = true, | |||||
| string kernel_initializer = "glorot_uniform", | |||||
| string recurrent_initializer = "orthogonal", | |||||
| string bias_initializer = "zeros", | |||||
| float dropout = 0f, | |||||
| float recurrent_dropout = 0f); | |||||
| public IRnnCell StackedRNNCells( | |||||
| IEnumerable<IRnnCell> cells); | |||||
| public ILayer SimpleRNN(int units, | public ILayer SimpleRNN(int units, | ||||
| string activation = "tanh", | string activation = "tanh", | ||||
| string kernel_initializer = "glorot_uniform", | string kernel_initializer = "glorot_uniform", | ||||
| @@ -200,6 +227,69 @@ namespace Tensorflow.Keras.Layers | |||||
| bool return_sequences = false, | bool return_sequences = false, | ||||
| bool return_state = false); | bool return_state = false); | ||||
| public ILayer RNN( | |||||
| IRnnCell cell, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool unroll = false, | |||||
| bool time_major = false | |||||
| ); | |||||
| public ILayer RNN( | |||||
| IEnumerable<IRnnCell> cell, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool unroll = false, | |||||
| bool time_major = false | |||||
| ); | |||||
| public IRnnCell GRUCell( | |||||
| int units, | |||||
| string activation = "tanh", | |||||
| string recurrent_activation = "sigmoid", | |||||
| bool use_bias = true, | |||||
| string kernel_initializer = "glorot_uniform", | |||||
| string recurrent_initializer = "orthogonal", | |||||
| string bias_initializer = "zeros", | |||||
| float dropout = 0f, | |||||
| float recurrent_dropout = 0f, | |||||
| bool reset_after = true); | |||||
| public ILayer GRU( | |||||
| int units, | |||||
| string activation = "tanh", | |||||
| string recurrent_activation = "sigmoid", | |||||
| bool use_bias = true, | |||||
| string kernel_initializer = "glorot_uniform", | |||||
| string recurrent_initializer = "orthogonal", | |||||
| string bias_initializer = "zeros", | |||||
| float dropout = 0f, | |||||
| float recurrent_dropout = 0f, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool unroll = false, | |||||
| bool time_major = false, | |||||
| bool reset_after = true | |||||
| ); | |||||
| /// <summary> | |||||
| /// Bidirectional wrapper for RNNs. | |||||
| /// </summary> | |||||
| /// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`</param> | |||||
| /// automatically.</param> | |||||
| /// <returns></returns> | |||||
| public ILayer Bidirectional( | |||||
| ILayer layer, | |||||
| string merge_mode = "concat", | |||||
| NDArray weights = null, | |||||
| ILayer backward_layer = null); | |||||
| public ILayer Subtract(); | public ILayer Subtract(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,25 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Common.Types; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| public interface IRnnCell: ILayer | |||||
| { | |||||
| /// <summary> | |||||
| /// If the derived class tends to not implement it, please return null. | |||||
| /// </summary> | |||||
| INestStructure<long>? StateSize { get; } | |||||
| /// <summary> | |||||
| /// If the derived class tends to not implement it, please return null. | |||||
| /// </summary> | |||||
| INestStructure<long>? OutputSize { get; } | |||||
| /// <summary> | |||||
| /// Whether the optional RNN args are supported when appying the layer. | |||||
| /// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`. | |||||
| /// </summary> | |||||
| bool SupportOptionalArgs { get; } | |||||
| Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,12 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| public interface IStackedRnnCells : IRnnCell | |||||
| { | |||||
| int Count { get; } | |||||
| IRnnCell this[int idx] { get; } | |||||
| } | |||||
| } | |||||
| @@ -3,6 +3,7 @@ using Newtonsoft.Json; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Common.Types; | |||||
| namespace Tensorflow.Keras.Saving.Json | namespace Tensorflow.Keras.Saving.Json | ||||
| { | { | ||||
| @@ -6,6 +6,7 @@ using System.Text; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using OneOf.Types; | using OneOf.Types; | ||||
| using Tensorflow.Keras.Saving.Json; | using Tensorflow.Keras.Saving.Json; | ||||
| using Tensorflow.Common.Types; | |||||
| namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
| { | { | ||||
| @@ -74,8 +74,3 @@ namespace Tensorflow | |||||
| => IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; | => IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; | ||||
| } | } | ||||
| } | } | ||||
| namespace System.Runtime.CompilerServices | |||||
| { | |||||
| internal static class IsExternalInit { } | |||||
| } | |||||
| @@ -107,9 +107,15 @@ namespace Tensorflow.NumPy | |||||
| public static implicit operator NDArray(bool value) | public static implicit operator NDArray(bool value) | ||||
| => new NDArray(value); | => new NDArray(value); | ||||
| public static implicit operator NDArray(byte value) | |||||
| => new NDArray(value); | |||||
| public static implicit operator NDArray(int value) | public static implicit operator NDArray(int value) | ||||
| => new NDArray(value); | => new NDArray(value); | ||||
| public static implicit operator NDArray(long value) | |||||
| => new NDArray(value); | |||||
| public static implicit operator NDArray(float value) | public static implicit operator NDArray(float value) | ||||
| => new NDArray(value); | => new NDArray(value); | ||||
| @@ -7,7 +7,7 @@ namespace Tensorflow.NumPy | |||||
| { | { | ||||
| public class NDArrayRender | public class NDArrayRender | ||||
| { | { | ||||
| public static string ToString(NDArray array) | |||||
| public static string ToString(NDArray array, int maxLength = 10) | |||||
| { | { | ||||
| Shape shape = array.shape; | Shape shape = array.shape; | ||||
| if (shape.IsScalar) | if (shape.IsScalar) | ||||
| @@ -15,12 +15,12 @@ namespace Tensorflow.NumPy | |||||
| var s = new StringBuilder(); | var s = new StringBuilder(); | ||||
| s.Append("array("); | s.Append("array("); | ||||
| Build(s, array); | |||||
| Build(s, array, maxLength); | |||||
| s.Append(")"); | s.Append(")"); | ||||
| return s.ToString(); | return s.ToString(); | ||||
| } | } | ||||
| static void Build(StringBuilder s, NDArray array) | |||||
| static void Build(StringBuilder s, NDArray array, int maxLength) | |||||
| { | { | ||||
| var shape = array.shape; | var shape = array.shape; | ||||
| @@ -35,11 +35,11 @@ namespace Tensorflow.NumPy | |||||
| var len = shape[0]; | var len = shape[0]; | ||||
| s.Append("["); | s.Append("["); | ||||
| if (len <= 10) | |||||
| if (len <= maxLength) | |||||
| { | { | ||||
| for (int i = 0; i < len; i++) | for (int i = 0; i < len; i++) | ||||
| { | { | ||||
| Build(s, array[i]); | |||||
| Build(s, array[i], maxLength); | |||||
| if (i < len - 1) | if (i < len - 1) | ||||
| { | { | ||||
| s.Append(", "); | s.Append(", "); | ||||
| @@ -49,9 +49,9 @@ namespace Tensorflow.NumPy | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| for (int i = 0; i < 5; i++) | |||||
| for (int i = 0; i < maxLength / 2; i++) | |||||
| { | { | ||||
| Build(s, array[i]); | |||||
| Build(s, array[i], maxLength); | |||||
| if (i < len - 1) | if (i < len - 1) | ||||
| { | { | ||||
| s.Append(", "); | s.Append(", "); | ||||
| @@ -62,9 +62,9 @@ namespace Tensorflow.NumPy | |||||
| s.Append(" ... "); | s.Append(" ... "); | ||||
| s.AppendLine(); | s.AppendLine(); | ||||
| for (int i = (int)len - 5; i < len; i++) | |||||
| for (int i = (int)len - maxLength / 2; i < len; i++) | |||||
| { | { | ||||
| Build(s, array[i]); | |||||
| Build(s, array[i], maxLength); | |||||
| if (i < len - 1) | if (i < len - 1) | ||||
| { | { | ||||
| s.Append(", "); | s.Append(", "); | ||||
| @@ -13,6 +13,10 @@ namespace Tensorflow.NumPy | |||||
| public static NDArray argmax(NDArray a, Axis? axis = null) | public static NDArray argmax(NDArray a, Axis? axis = null) | ||||
| => new NDArray(math_ops.argmax(a, axis ?? 0)); | => new NDArray(math_ops.argmax(a, axis ?? 0)); | ||||
| [AutoNumPy] | |||||
| public static NDArray argmin(NDArray a, Axis? axis = null) | |||||
| => new NDArray(math_ops.argmin(a, axis ?? 0)); | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray argsort(NDArray a, Axis? axis = null) | public static NDArray argsort(NDArray a, Axis? axis = null) | ||||
| => new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); | => new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); | ||||
| @@ -10,10 +10,10 @@ namespace Tensorflow.NumPy | |||||
| public partial class np | public partial class np | ||||
| { | { | ||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.arg_min(x, axis)); | |||||
| public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.min(x, axis)); | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis)); | |||||
| public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.max(x, axis)); | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) | public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) | ||||
| @@ -49,9 +49,30 @@ namespace Tensorflow.NumPy | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray prod<T>(params T[] array) where T : unmanaged | public static NDArray prod<T>(params T[] array) where T : unmanaged | ||||
| => new NDArray(tf.reduce_prod(new NDArray(array))); | => new NDArray(tf.reduce_prod(new NDArray(array))); | ||||
| [AutoNumPy] | |||||
| public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null) | |||||
| { | |||||
| //if axes mentioned | |||||
| if (axes != null) | |||||
| { | |||||
| return new NDArray(tf.dot_prod(x1, x2, axes, name)); | |||||
| } | |||||
| if (x1.shape.ndim > 1) | |||||
| { | |||||
| x1 = GetFlattenArray(x1); | |||||
| } | |||||
| if (x2.shape.ndim > 1) | |||||
| { | |||||
| x2 = GetFlattenArray(x2); | |||||
| } | |||||
| //if axes not mentioned, default 0,0 | |||||
| return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name)); | |||||
| } | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y)); | public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y)); | ||||
| [AutoNumPy] | |||||
| public static NDArray square(NDArray x) => new NDArray(tf.square(x)); | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x)); | public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x)); | ||||
| @@ -19,13 +19,14 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Common.Types; | |||||
| using Tensorflow.Keras.Saving.Common; | using Tensorflow.Keras.Saving.Common; | ||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| [JsonConverter(typeof(CustomizedShapeJsonConverter))] | [JsonConverter(typeof(CustomizedShapeJsonConverter))] | ||||
| public class Shape | |||||
| public class Shape : INestStructure<long> | |||||
| { | { | ||||
| public int ndim => _dims == null ? -1 : _dims.Length; | public int ndim => _dims == null ? -1 : _dims.Length; | ||||
| long[] _dims; | long[] _dims; | ||||
| @@ -41,6 +42,27 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public NestType NestType => NestType.List; | |||||
| public int ShallowNestedCount => ndim; | |||||
| /// <summary> | |||||
| /// The total item count of depth 1 of the nested structure. | |||||
| /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. | |||||
| /// </summary> | |||||
| public int TotalNestedCount => ndim; | |||||
| public IEnumerable<long> Flatten() => dims.Select(x => x); | |||||
| public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func) | |||||
| { | |||||
| return new NestList<TOut>(dims.Select(x => func(x))); | |||||
| } | |||||
| public Nest<long> AsNest() | |||||
| { | |||||
| return new NestList<long>(Flatten()).AsNest(); | |||||
| } | |||||
| #region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges | #region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges | ||||
| public int Length => ndim; | public int Length => ndim; | ||||
| public long[] Slice(int start, int length) | public long[] Slice(int start, int length) | ||||
| @@ -0,0 +1,22 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Operations.Initializers | |||||
| { | |||||
| /// <summary> | |||||
| /// An initializer specially used for debugging (to load weights from disk). | |||||
| /// </summary> | |||||
| class NpyLoadInitializer : IInitializer | |||||
| { | |||||
| string _path; | |||||
| public NpyLoadInitializer(string path) { _path = path; } | |||||
| public string ClassName => ""; | |||||
| public IDictionary<string, object> Config => new Dictionary<string, object>(); | |||||
| public Tensor Apply(InitializerArgs args) | |||||
| { | |||||
| return np.load(_path); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -53,13 +53,12 @@ public class Orthogonal : IInitializer | |||||
| // Compute the qr factorization | // Compute the qr factorization | ||||
| var (q, r) = tf.linalg.qr(a, full_matrices: false); | var (q, r) = tf.linalg.qr(a, full_matrices: false); | ||||
| // Make Q uniform | // Make Q uniform | ||||
| var d = tf.linalg.tensor_diag_part(r); | |||||
| var d = tf.linalg.tensor_diag_part(r.Single); | |||||
| q *= tf.sign(d); | q *= tf.sign(d); | ||||
| if (num_rows < num_cols) | if (num_rows < num_cols) | ||||
| { | { | ||||
| // q = tf.linalg.matrix_transpose(q); | |||||
| throw new NotImplementedException(""); | |||||
| q = array_ops.matrix_transpose(q); | |||||
| } | } | ||||
| return _gain * tf.reshape(q, shape); | return _gain * tf.reshape(q, shape); | ||||
| @@ -11,6 +11,7 @@ namespace Tensorflow | |||||
| /// Basic LSTM recurrent network cell. | /// Basic LSTM recurrent network cell. | ||||
| /// The implementation is based on: http://arxiv.org/abs/1409.2329. | /// The implementation is based on: http://arxiv.org/abs/1409.2329. | ||||
| /// </summary> | /// </summary> | ||||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||||
| public class BasicLstmCell : LayerRnnCell | public class BasicLstmCell : LayerRnnCell | ||||
| { | { | ||||
| int _num_units; | int _num_units; | ||||
| @@ -88,7 +89,7 @@ namespace Tensorflow | |||||
| gate_inputs = nn_ops.bias_add(gate_inputs, _bias); | gate_inputs = nn_ops.bias_add(gate_inputs, _bias); | ||||
| // i = input_gate, j = new_input, f = forget_gate, o = output_gate | // i = input_gate, j = new_input, f = forget_gate, o = output_gate | ||||
| var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one); | |||||
| var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); | |||||
| var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | ||||
| var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | ||||
| @@ -20,6 +20,7 @@ using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||||
| public class BasicRnnCell : LayerRnnCell | public class BasicRnnCell : LayerRnnCell | ||||
| { | { | ||||
| int _num_units; | int _num_units; | ||||
| @@ -19,6 +19,7 @@ using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||||
| public class LayerRnnCell : RnnCell | public class LayerRnnCell : RnnCell | ||||
| { | { | ||||
| protected InputSpec inputSpec; | protected InputSpec inputSpec; | ||||
| @@ -16,10 +16,11 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Common.Types; | |||||
| using Tensorflow.Keras; | using Tensorflow.Keras; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Layers; | |||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| @@ -50,7 +51,8 @@ namespace Tensorflow | |||||
| /// matching structure of Tensors having shape `[batch_size].concatenate(s)` | /// matching structure of Tensors having shape `[batch_size].concatenate(s)` | ||||
| /// for each `s` in `self.batch_size`. | /// for each `s` in `self.batch_size`. | ||||
| /// </summary> | /// </summary> | ||||
| public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell | |||||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||||
| public abstract class RnnCell : ILayer, IRnnCell | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Attribute that indicates whether the cell is a TF RNN cell, due the slight | /// Attribute that indicates whether the cell is a TF RNN cell, due the slight | ||||
| @@ -142,7 +144,7 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("_zero_state_tensors"); | throw new NotImplementedException("_zero_state_tensors"); | ||||
| } | } | ||||
| public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null) | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| @@ -173,5 +175,18 @@ namespace Tensorflow | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public INestStructure<long> StateSize => throw new NotImplementedException(); | |||||
| public INestStructure<long> OutputSize => throw new NotImplementedException(); | |||||
| public bool IsTFRnnCell => throw new NotImplementedException(); | |||||
| public bool SupportOptionalArgs => throw new NotImplementedException(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -15,9 +15,11 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using Google.Protobuf.Collections; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Functions; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.OpDef.Types; | using static Tensorflow.OpDef.Types; | ||||
| @@ -387,9 +389,13 @@ namespace Tensorflow | |||||
| case "list(type)": | case "list(type)": | ||||
| attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); | attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); | ||||
| break; | break; | ||||
| case "list(float)": | |||||
| if (value != null) | |||||
| attr_value.List.F.AddRange((value as IEnumerable<float>).ToArray()); | |||||
| break; | |||||
| case "list(int)": | case "list(int)": | ||||
| if (value != null) | if (value != null) | ||||
| attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x))); | |||||
| attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x))); | |||||
| break; | break; | ||||
| case "bool": | case "bool": | ||||
| attr_value.B = (bool)value; | attr_value.B = (bool)value; | ||||
| @@ -420,6 +426,15 @@ namespace Tensorflow | |||||
| case "list(shape)": | case "list(shape)": | ||||
| attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def))); | attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def))); | ||||
| break; | break; | ||||
| case "func": | |||||
| attr_value.Func = _MakeFunc(value, attr_def.Name); | |||||
| break; | |||||
| case "list(func)": | |||||
| attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name)); | |||||
| break; | |||||
| case "list(string)": | |||||
| attr_value.List.S.AddRange((value as IEnumerable<string>).Select(x => ByteString.CopyFromUtf8(x))); | |||||
| break; | |||||
| default: | default: | ||||
| throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | ||||
| } | } | ||||
| @@ -427,6 +442,47 @@ namespace Tensorflow | |||||
| return attr_value; | return attr_value; | ||||
| } | } | ||||
| private NameAttrList _MakeFunc(object func, string arg_name) | |||||
| { | |||||
| if(func is NameAttrList attrList) | |||||
| { | |||||
| return attrList; | |||||
| } | |||||
| NameAttrList fn_attr; | |||||
| if(func is string funcStr) | |||||
| { | |||||
| fn_attr = new NameAttrList() { Name = funcStr }; | |||||
| } | |||||
| else if(func is ConcreteFunction concrete) | |||||
| { | |||||
| concrete.AddTograph(ops.get_default_graph()); | |||||
| fn_attr = concrete.AsNameAttrList; | |||||
| } | |||||
| else if(func is EagerDefinedFunction eager) | |||||
| { | |||||
| eager.AddToGraph(ops.get_default_graph()); | |||||
| fn_attr = new NameAttrList() { Name = eager.Name }; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}"); | |||||
| } | |||||
| return fn_attr; | |||||
| } | |||||
| private List<NameAttrList> _MakeFuncList(object funcList, string arg_name) | |||||
| { | |||||
| List<NameAttrList> res = new List<NameAttrList>(); | |||||
| if(funcList is IEnumerable enumerable) | |||||
| { | |||||
| foreach(var func in enumerable) | |||||
| { | |||||
| res.Add(_MakeFunc(func, arg_name)); | |||||
| } | |||||
| } | |||||
| return res; | |||||
| } | |||||
| private bool _IsListParameter(ArgDef arg) | private bool _IsListParameter(ArgDef arg) | ||||
| { | { | ||||
| if (!String.IsNullOrEmpty(arg.NumberAttr)) | if (!String.IsNullOrEmpty(arg.NumberAttr)) | ||||