| @@ -16,6 +16,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.IO; | using Tensorflow.IO; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -59,6 +60,10 @@ namespace Tensorflow | |||||
| string name = null) | string name = null) | ||||
| => image_ops_impl.resize_images(images, size, method, preserve_aspect_ratio, antialias, name); | => image_ops_impl.resize_images(images, size, method, preserve_aspect_ratio, antialias, name); | ||||
| public Tensor resize_images_v2(Tensor images, TensorShape size, string method = ResizeMethod.BILINEAR, bool preserve_aspect_ratio = false, bool antialias = false, | |||||
| string name = null) | |||||
| => image_ops_impl.resize_images(images, tf.constant(size.dims), method, preserve_aspect_ratio, antialias, name); | |||||
| public Tensor resize_images_with_pad(Tensor image, int target_height, int target_width, string method, bool antialias) | public Tensor resize_images_with_pad(Tensor image, int target_height, int target_width, string method, bool antialias) | ||||
| => image_ops_impl.resize_images_with_pad(image, target_height, target_width, method, antialias); | => image_ops_impl.resize_images_with_pad(image, target_height, target_width, method, antialias); | ||||
| @@ -160,7 +165,7 @@ namespace Tensorflow | |||||
| int ratio = 1, | int ratio = 1, | ||||
| bool fancy_upscaling = true, | bool fancy_upscaling = true, | ||||
| bool try_recover_truncated = false, | bool try_recover_truncated = false, | ||||
| float acceptable_fraction = 1, | |||||
| int acceptable_fraction = 1, | |||||
| string dct_method = "", | string dct_method = "", | ||||
| string name = null) | string name = null) | ||||
| => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio, | ||||
| @@ -376,6 +376,9 @@ namespace Tensorflow.Eager | |||||
| case TF_AttrType.TF_ATTR_INT: | case TF_AttrType.TF_ATTR_INT: | ||||
| c_api.TFE_OpSetAttrInt(op, key, Convert.ToInt64(value)); | c_api.TFE_OpSetAttrInt(op, key, Convert.ToInt64(value)); | ||||
| break; | break; | ||||
| case TF_AttrType.TF_ATTR_FLOAT: | |||||
| c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); | |||||
| break; | |||||
| case TF_AttrType.TF_ATTR_SHAPE: | case TF_AttrType.TF_ATTR_SHAPE: | ||||
| var dims = (value as int[]).Select(x => (long)x).ToArray(); | var dims = (value as int[]).Select(x => (long)x).ToArray(); | ||||
| c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle); | c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle); | ||||
| @@ -176,6 +176,9 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_OpSetAttrInt(SafeOpHandle op, string attr_name, long value); | public static extern void TFE_OpSetAttrInt(SafeOpHandle op, string attr_name, long value); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_OpSetAttrFloat(SafeOpHandle op, string attr_name, float value); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| @@ -0,0 +1,43 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Preprocessings | |||||
| { | |||||
| public partial class DatasetUtils | |||||
| { | |||||
| /// <summary> | |||||
| /// Potentially restict samples & labels to a training or validation split. | |||||
| /// </summary> | |||||
| /// <param name="samples"></param> | |||||
| /// <param name="labels"></param> | |||||
| /// <param name="validation_split"></param> | |||||
| /// <param name="subset"></param> | |||||
| /// <returns></returns> | |||||
| public (T1[], T2[]) get_training_or_validation_split<T1, T2>(T1[] samples, | |||||
| T2[] labels, | |||||
| float validation_split, | |||||
| string subset) | |||||
| { | |||||
| var num_val_samples = Convert.ToInt32(samples.Length * validation_split); | |||||
| if (subset == "training") | |||||
| { | |||||
| Console.WriteLine($"Using {samples.Length - num_val_samples} files for training."); | |||||
| samples = samples[..^num_val_samples]; | |||||
| labels = labels[..^num_val_samples]; | |||||
| } | |||||
| else if (subset == "validation") | |||||
| { | |||||
| Console.WriteLine($"Using {num_val_samples} files for validation."); | |||||
| samples = samples[samples.Length..]; | |||||
| labels = labels[samples.Length..]; | |||||
| } | |||||
| else | |||||
| throw new NotImplementedException(""); | |||||
| return (samples, labels); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,5 +1,8 @@ | |||||
| using System; | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.Preprocessings | namespace Tensorflow.Keras.Preprocessings | ||||
| @@ -20,14 +23,45 @@ namespace Tensorflow.Keras.Preprocessings | |||||
| /// file_paths, labels, class_names | /// file_paths, labels, class_names | ||||
| /// </returns> | /// </returns> | ||||
| public (string[], int[], string[]) index_directory(string directory, | public (string[], int[], string[]) index_directory(string directory, | ||||
| string labels, | |||||
| string[] formats, | |||||
| string class_names = null, | |||||
| string[] formats = null, | |||||
| string[] class_names = null, | |||||
| bool shuffle = true, | bool shuffle = true, | ||||
| int? seed = null, | int? seed = null, | ||||
| bool follow_links = false) | bool follow_links = false) | ||||
| { | { | ||||
| throw new NotImplementedException(""); | |||||
| var labels = new List<int>(); | |||||
| var file_paths = new List<string>(); | |||||
| var class_dirs = Directory.GetDirectories(directory); | |||||
| class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar)[^1]).ToArray(); | |||||
| for (var label = 0; label < class_dirs.Length; label++) | |||||
| { | |||||
| var files = Directory.GetFiles(class_dirs[label]); | |||||
| file_paths.AddRange(files); | |||||
| labels.AddRange(Enumerable.Range(0, files.Length).Select(x => label)); | |||||
| } | |||||
| var return_labels = new int[labels.Count]; | |||||
| var return_file_paths = new string[file_paths.Count]; | |||||
| if (shuffle) | |||||
| { | |||||
| if (!seed.HasValue) | |||||
| seed = np.random.randint((long)1e6); | |||||
| var random_index = np.arange(labels.Count); | |||||
| var rng = np.random.RandomState(seed.Value); | |||||
| rng.shuffle(random_index); | |||||
| var index = random_index.ToArray<int>(); | |||||
| for (int i = 0; i< labels.Count; i++) | |||||
| { | |||||
| return_labels[i] = labels[index[i]]; | |||||
| return_file_paths[i] = file_paths[index[i]]; | |||||
| } | |||||
| } | |||||
| return (return_file_paths, return_labels, class_names); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -28,7 +28,7 @@ namespace Tensorflow.Keras | |||||
| public Tensor image_dataset_from_directory(string directory, | public Tensor image_dataset_from_directory(string directory, | ||||
| string labels = "inferred", | string labels = "inferred", | ||||
| string label_mode = "int", | string label_mode = "int", | ||||
| string class_names = null, | |||||
| string[] class_names = null, | |||||
| string color_mode = "rgb", | string color_mode = "rgb", | ||||
| int batch_size = 32, | int batch_size = 32, | ||||
| TensorShape image_size = null, | TensorShape image_size = null, | ||||
| @@ -44,13 +44,15 @@ namespace Tensorflow.Keras | |||||
| num_channels = 3; | num_channels = 3; | ||||
| // C:/Users/haipi/.keras/datasets/flower_photos | // C:/Users/haipi/.keras/datasets/flower_photos | ||||
| var (image_paths, label_list, class_name_list) = tf.keras.preprocessing.dataset_utils.index_directory(directory, | var (image_paths, label_list, class_name_list) = tf.keras.preprocessing.dataset_utils.index_directory(directory, | ||||
| labels, | |||||
| WHITELIST_FORMATS, | |||||
| formats: WHITELIST_FORMATS, | |||||
| class_names: class_names, | class_names: class_names, | ||||
| shuffle: shuffle, | shuffle: shuffle, | ||||
| seed: seed, | seed: seed, | ||||
| follow_links: follow_links); | follow_links: follow_links); | ||||
| (image_paths, label_list) = tf.keras.preprocessing.dataset_utils.get_training_or_validation_split(image_paths, label_list, validation_split, subset); | |||||
| paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation); | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,32 @@ | |||||
| using System; | |||||
| using System.Globalization; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras | |||||
| { | |||||
| public partial class Preprocessing | |||||
| { | |||||
| public Tensor paths_and_labels_to_dataset(string[] image_paths, | |||||
| TensorShape image_size, | |||||
| int num_channels, | |||||
| int[] labels, | |||||
| string label_mode, | |||||
| int num_classes, | |||||
| string interpolation) | |||||
| { | |||||
| foreach (var image_path in image_paths) | |||||
| path_to_image(image_path, image_size, num_channels, interpolation); | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| Tensor path_to_image(string path, TensorShape image_size, int num_channels, string interpolation) | |||||
| { | |||||
| var img = tf.io.read_file(path); | |||||
| img = tf.image.decode_image( | |||||
| img, channels: num_channels, expand_animations: false); | |||||
| img = tf.image.resize_images_v2(img, image_size, method: interpolation); | |||||
| return img; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -86,7 +86,7 @@ namespace Tensorflow | |||||
| var guarded_assert = cond(condition, false_assert, true_assert, name: "AssertGuard"); | var guarded_assert = cond(condition, false_assert, true_assert, name: "AssertGuard"); | ||||
| return guarded_assert[0].op; | |||||
| return guarded_assert == null ? null : guarded_assert[0].op; | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -423,8 +423,6 @@ namespace Tensorflow | |||||
| return true_fn() as Tensor; | return true_fn() as Tensor; | ||||
| else | else | ||||
| return false_fn() as Tensor; | return false_fn() as Tensor; | ||||
| return null; | |||||
| } | } | ||||
| // Add the Switch to the graph. | // Add the Switch to the graph. | ||||
| @@ -507,8 +505,6 @@ namespace Tensorflow | |||||
| return true_fn() as Tensor[]; | return true_fn() as Tensor[]; | ||||
| else | else | ||||
| return false_fn() as Tensor[]; | return false_fn() as Tensor[]; | ||||
| return null; | |||||
| } | } | ||||
| // Add the Switch to the graph. | // Add the Switch to the graph. | ||||
| @@ -66,14 +66,24 @@ namespace Tensorflow | |||||
| int ratio = 1, | int ratio = 1, | ||||
| bool fancy_upscaling = true, | bool fancy_upscaling = true, | ||||
| bool try_recover_truncated = false, | bool try_recover_truncated = false, | ||||
| float acceptable_fraction = 1, | |||||
| int acceptable_fraction = 1, | |||||
| string dct_method = "", | string dct_method = "", | ||||
| string name = null) | string name = null) | ||||
| { | { | ||||
| // Add nodes to the TensorFlow graph. | // Add nodes to the TensorFlow graph. | ||||
| if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
| { | { | ||||
| throw new NotImplementedException("decode_jpeg"); | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "DecodeJpeg", name, | |||||
| null, | |||||
| contents, | |||||
| "channels", channels, | |||||
| "ratio", ratio, | |||||
| "fancy_upscaling", fancy_upscaling, | |||||
| "try_recover_truncated", try_recover_truncated, | |||||
| "acceptable_fraction", acceptable_fraction, | |||||
| "dct_method", dct_method); | |||||
| return results[0]; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -171,17 +181,42 @@ namespace Tensorflow | |||||
| "half_pixel_centers", half_pixel_centers); | "half_pixel_centers", half_pixel_centers); | ||||
| return results[0]; | return results[0]; | ||||
| } | } | ||||
| else | |||||
| var _op = tf.OpDefLib._apply_op_helper("ResizeBilinear", name: name, args: new | |||||
| { | { | ||||
| var _op = tf.OpDefLib._apply_op_helper("ResizeBilinear", name: name, args: new | |||||
| { | |||||
| images, | |||||
| size, | |||||
| align_corners | |||||
| }); | |||||
| images, | |||||
| size, | |||||
| align_corners | |||||
| }); | |||||
| return _op.outputs[0]; | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| public static Tensor resize_bicubic(Tensor images, | |||||
| Tensor size, | |||||
| bool align_corners = false, | |||||
| bool half_pixel_centers = false, | |||||
| string name = null) | |||||
| { | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "ResizeBicubic", name, | |||||
| null, | |||||
| images, size, | |||||
| "align_corners", align_corners, | |||||
| "half_pixel_centers", half_pixel_centers); | |||||
| return results[0]; | |||||
| } | } | ||||
| var _op = tf.OpDefLib._apply_op_helper("ResizeBicubic", name: name, args: new | |||||
| { | |||||
| images, | |||||
| size, | |||||
| align_corners | |||||
| }); | |||||
| return _op.outputs[0]; | |||||
| } | } | ||||
| public static Tensor resize_nearest_neighbor<Tsize>(Tensor images, Tsize size, bool align_corners = false, | public static Tensor resize_nearest_neighbor<Tsize>(Tensor images, Tsize size, bool align_corners = false, | ||||
| @@ -24715,13 +24715,12 @@ namespace Tensorflow.Operations | |||||
| /// <remarks> | /// <remarks> | ||||
| /// Input images can be of different types but output images are always float. | /// Input images can be of different types but output images are always float. | ||||
| /// </remarks> | /// </remarks> | ||||
| public static Tensor resize_bicubic (Tensor images, Tensor size, bool? align_corners = null, string name = "ResizeBicubic") | |||||
| public static Tensor resize_bicubic (Tensor images, Tensor size, bool align_corners = false, bool half_pixel_centers = false, string name = "ResizeBicubic") | |||||
| { | { | ||||
| var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
| dict["images"] = images; | dict["images"] = images; | ||||
| dict["size"] = size; | dict["size"] = size; | ||||
| if (align_corners.HasValue) | |||||
| dict["align_corners"] = align_corners.Value; | |||||
| dict["align_corners"] = align_corners; | |||||
| var op = tf.OpDefLib._apply_op_helper("ResizeBicubic", name: name, keywords: dict); | var op = tf.OpDefLib._apply_op_helper("ResizeBicubic", name: name, keywords: dict); | ||||
| return op.output; | return op.output; | ||||
| } | } | ||||
| @@ -686,18 +686,20 @@ or rank = 4. Had rank = {0}", rank)); | |||||
| } else if (images.TensorShape.ndim != 4) | } else if (images.TensorShape.ndim != 4) | ||||
| throw new ValueError("\'images\' must have either 3 or 4 dimensions."); | throw new ValueError("\'images\' must have either 3 or 4 dimensions."); | ||||
| var _hw_ = images.TensorShape.as_list(); | |||||
| var (height, width) = (images.dims[1], images.dims[2]); | |||||
| try | try | ||||
| { | { | ||||
| size = ops.convert_to_tensor(size, dtypes.int32, name: "size"); | size = ops.convert_to_tensor(size, dtypes.int32, name: "size"); | ||||
| } catch (Exception ex) | |||||
| } | |||||
| catch (Exception ex) | |||||
| { | { | ||||
| if (ex is TypeError || ex is ValueError) | if (ex is TypeError || ex is ValueError) | ||||
| throw new ValueError("\'size\' must be a 1-D int32 Tensor"); | throw new ValueError("\'size\' must be a 1-D int32 Tensor"); | ||||
| else | else | ||||
| throw; | throw; | ||||
| } | } | ||||
| if (!size.TensorShape.is_compatible_with(new [] {2})) | if (!size.TensorShape.is_compatible_with(new [] {2})) | ||||
| throw new ValueError(@"\'size\' must be a 1-D Tensor of 2 elements: | throw new ValueError(@"\'size\' must be a 1-D Tensor of 2 elements: | ||||
| new_height, new_width"); | new_height, new_width"); | ||||
| @@ -736,9 +738,9 @@ new_height, new_width"); | |||||
| bool x_null = true; | bool x_null = true; | ||||
| if (skip_resize_if_same) | if (skip_resize_if_same) | ||||
| { | { | ||||
| foreach (int x in new [] {new_width_const, _hw_[2], new_height_const, _hw_[1]}) | |||||
| foreach (int x in new [] {new_width_const, width, new_height_const, height}) | |||||
| { | { | ||||
| if (_hw_[2] != new_width_const && _hw_[1] == new_height_const) | |||||
| if (width != new_width_const && height == new_height_const) | |||||
| { | { | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -753,8 +755,8 @@ new_height, new_width"); | |||||
| } | } | ||||
| images = resizer_fn(images, size); | images = resizer_fn(images, size); | ||||
| images.set_shape(new TensorShape(new int[] {0, new_height_const, new_width_const, 0})); | |||||
| // images.set_shape(new TensorShape(new int[] { -1, new_height_const, new_width_const, -1 })); | |||||
| if (!is_batch) | if (!is_batch) | ||||
| images = array_ops.squeeze(images, axis: new int[] {0}); | images = array_ops.squeeze(images, axis: new int[] {0}); | ||||
| @@ -792,17 +794,20 @@ new_height, new_width"); | |||||
| if (antialias) | if (antialias) | ||||
| return resize_with_scale_and_translate("triangle"); | return resize_with_scale_and_translate("triangle"); | ||||
| else | else | ||||
| return gen_image_ops.resize_bilinear( | |||||
| images_t, new_size, true); | |||||
| return gen_image_ops.resize_bilinear(images_t, | |||||
| new_size, | |||||
| half_pixel_centers: true); | |||||
| else if (method == ResizeMethod.NEAREST_NEIGHBOR) | else if (method == ResizeMethod.NEAREST_NEIGHBOR) | ||||
| return gen_image_ops.resize_nearest_neighbor( | |||||
| images_t, new_size, true); | |||||
| return gen_image_ops.resize_nearest_neighbor(images_t, | |||||
| new_size, | |||||
| half_pixel_centers: true); | |||||
| else if (method == ResizeMethod.BICUBIC) | else if (method == ResizeMethod.BICUBIC) | ||||
| if (antialias) | if (antialias) | ||||
| return resize_with_scale_and_translate("keyscubic"); | return resize_with_scale_and_translate("keyscubic"); | ||||
| else | else | ||||
| return gen_ops.resize_bicubic( | |||||
| images_t, new_size, true); | |||||
| return gen_image_ops.resize_bicubic(images_t, | |||||
| new_size, | |||||
| half_pixel_centers: true); | |||||
| else if (method == ResizeMethod.AREA) | else if (method == ResizeMethod.AREA) | ||||
| return gen_ops.resize_area(images_t, new_size); | return gen_ops.resize_area(images_t, new_size); | ||||
| else if (Array.Exists(scale_and_translate_methods, method => method == method)) | else if (Array.Exists(scale_and_translate_methods, method => method == method)) | ||||
| @@ -2078,9 +2083,8 @@ new_height, new_width"); | |||||
| return tf_with(ops.name_scope(name, "is_jpeg"), scope => | return tf_with(ops.name_scope(name, "is_jpeg"), scope => | ||||
| { | { | ||||
| var substr = tf.strings.substr(contents, 0, 3); | var substr = tf.strings.substr(contents, 0, 3); | ||||
| var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff }); | |||||
| var jpg_tensor = tf.constant(jpg); | |||||
| var result = math_ops.equal(substr, jpg_tensor, name: name); | |||||
| var jpg = tf.constant(new byte[] { 0xff, 0xd8, 0xff }, TF_DataType.TF_STRING); | |||||
| var result = math_ops.equal(substr, jpg, name: name); | |||||
| return result; | return result; | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -5,7 +5,7 @@ | |||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <TargetTensorFlow>2.2.0</TargetTensorFlow> | <TargetTensorFlow>2.2.0</TargetTensorFlow> | ||||
| <Version>0.20.0-preview4</Version> | |||||
| <Version>0.20.0-preview5</Version> | |||||
| <LangVersion>8.0</LangVersion> | <LangVersion>8.0</LangVersion> | ||||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| @@ -79,7 +79,7 @@ Please be patient, we're working hard on missing functions, providing full tenso | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Google.Protobuf" Version="3.11.4" /> | <PackageReference Include="Google.Protobuf" Version="3.11.4" /> | ||||
| <PackageReference Include="NumSharp.Lite" Version="0.1.7" /> | |||||
| <PackageReference Include="NumSharp.Lite" Version="0.1.8" /> | |||||
| <PackageReference Include="Protobuf.Text" Version="0.4.0" /> | <PackageReference Include="Protobuf.Text" Version="0.4.0" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -143,7 +143,13 @@ namespace Tensorflow | |||||
| public bool is_compatible_with(TensorShape shape2) | public bool is_compatible_with(TensorShape shape2) | ||||
| { | { | ||||
| throw new NotImplementedException("TensorShape is_compatible_with"); | |||||
| if(dims != null && shape2.dims != null) | |||||
| { | |||||
| if (dims.Length != shape2.dims.Length) | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | } | ||||
| public void assert_has_rank(int rank) | public void assert_has_rank(int rank) | ||||
| @@ -231,17 +231,7 @@ namespace Tensorflow | |||||
| if (tensor.GetType() == typeof(EagerTensor)) | if (tensor.GetType() == typeof(EagerTensor)) | ||||
| { | { | ||||
| int[] dims = {}; | |||||
| foreach (int dim in tensor.numpy()) | |||||
| if (dim != 1) | |||||
| { | |||||
| dims[dims.Length] = dim; | |||||
| } else | |||||
| { | |||||
| // -1 == Unknown | |||||
| dims[dims.Length] = -1; | |||||
| } | |||||
| return new TensorShape(dims); | |||||
| return new TensorShape(tensor.numpy().ToArray<int>()); | |||||
| } | } | ||||
| if (tensor.TensorShape.ndim == 0) | if (tensor.TensorShape.ndim == 0) | ||||
| @@ -45,6 +45,19 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| Assert.AreEqual("", g._name_stack); | Assert.AreEqual("", g._name_stack); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void NameScopeInEagerMode() | |||||
| { | |||||
| tf.enable_eager_execution(); | |||||
| tf_with(new ops.NameScope("scope"), scope => | |||||
| { | |||||
| string name = scope; | |||||
| }); | |||||
| tf.compat.v1.disable_eager_execution(); | |||||
| } | |||||
| [TestMethod, Ignore("Unimplemented Usage")] | [TestMethod, Ignore("Unimplemented Usage")] | ||||
| public void NestedNameScope_Using() | public void NestedNameScope_Using() | ||||
| { | { | ||||
| @@ -46,7 +46,6 @@ | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.0" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.0" /> | ||||
| <PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> | <PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> | <PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> | ||||
| <PackageReference Include="NumSharp.Lite" Version="0.1.7" /> | |||||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> | <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||