| @@ -2,7 +2,7 @@ | |||
| <PropertyGroup> | |||
| <OutputType>Exe</OutputType> | |||
| <TargetFramework>netcoreapp3.1</TargetFramework> | |||
| <TargetFramework>net5.0</TargetFramework> | |||
| <RootNamespace>Tensorflow</RootNamespace> | |||
| <AssemblyName>Tensorflow</AssemblyName> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| @@ -43,7 +43,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public partial class c_api | |||
| { | |||
| public const string TensorFlowLibName = @"D:\Projects\tensorflow-haiping\bazel-bin\tensorflow\tensorflow"; | |||
| public const string TensorFlowLibName = "tensorflow"; | |||
| public static string StringPiece(IntPtr handle) | |||
| { | |||
| @@ -24,6 +24,30 @@ namespace Tensorflow | |||
| { | |||
| string_ops ops = new string_ops(); | |||
| /// <summary> | |||
| /// Converts all uppercase characters into their respective lowercase replacements. | |||
| /// </summary> | |||
| /// <param name="input"></param> | |||
| /// <param name="encoding"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor lower(Tensor input, string encoding = "", string name = null) | |||
| => ops.lower(input: input, encoding: encoding, name: name); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="input"></param> | |||
| /// <param name="pattern"></param> | |||
| /// <param name="rewrite"></param> | |||
| /// <param name="replace_global"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor regex_replace(Tensor input, string pattern, string rewrite, | |||
| bool replace_global = true, string name = null) | |||
| => ops.regex_replace(input, pattern, rewrite, | |||
| replace_global: replace_global, name: name); | |||
| /// <summary> | |||
| /// Return substrings from `Tensor` of strings. | |||
| /// </summary> | |||
| @@ -14,6 +14,7 @@ namespace Tensorflow | |||
| public class DatasetV2 : IDatasetV2 | |||
| { | |||
| protected dataset_ops ops = new dataset_ops(); | |||
| public string[] class_names { get; set; } | |||
| public Tensor variant_tensor { get; set; } | |||
| public TensorSpec[] structure { get; set; } | |||
| @@ -54,7 +55,7 @@ namespace Tensorflow | |||
| public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) | |||
| => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); | |||
| public IDatasetV2 map(Func<Tensor, Tensor> map_func, | |||
| public IDatasetV2 map(Func<Tensors, Tensors> map_func, | |||
| bool use_inter_op_parallelism = true, | |||
| bool preserve_cardinality = true, | |||
| bool use_legacy_function = false) | |||
| @@ -64,7 +65,7 @@ namespace Tensorflow | |||
| preserve_cardinality: preserve_cardinality, | |||
| use_legacy_function: use_legacy_function); | |||
| public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls = -1) | |||
| public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls) | |||
| => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls); | |||
| public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func) | |||
| @@ -6,6 +6,8 @@ namespace Tensorflow | |||
| { | |||
| public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)> | |||
| { | |||
| string[] class_names { get; set; } | |||
| Tensor variant_tensor { get; set; } | |||
| TensorShape[] output_shapes { get; } | |||
| @@ -62,13 +64,13 @@ namespace Tensorflow | |||
| IDatasetV2 optimize(string[] optimizations, string[] optimization_configs); | |||
| IDatasetV2 map(Func<Tensor, Tensor> map_func, | |||
| IDatasetV2 map(Func<Tensors, Tensors> map_func, | |||
| bool use_inter_op_parallelism = true, | |||
| bool preserve_cardinality = true, | |||
| bool use_legacy_function = false); | |||
| IDatasetV2 map(Func<Tensors, Tensors> map_func, | |||
| int num_parallel_calls = -1); | |||
| int num_parallel_calls); | |||
| IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); | |||
| @@ -10,7 +10,7 @@ namespace Tensorflow | |||
| public class MapDataset : UnaryDataset | |||
| { | |||
| public MapDataset(IDatasetV2 input_dataset, | |||
| Func<Tensor, Tensor> map_func, | |||
| Func<Tensors, Tensors> map_func, | |||
| bool use_inter_op_parallelism = true, | |||
| bool preserve_cardinality = false, | |||
| bool use_legacy_function = false) : base(input_dataset) | |||
| @@ -15,7 +15,7 @@ namespace Tensorflow.Framework.Models | |||
| if (_shape.ndim == 0) | |||
| throw new ValueError("Unbatching a tensor is only supported for rank >= 1"); | |||
| return new TensorSpec(_shape.dims[1..], _dtype); | |||
| return new TensorSpec(_shape.dims.Skip(1).ToArray(), _dtype); | |||
| } | |||
| public TensorSpec _batch(int dim = -1) | |||
| @@ -30,7 +30,7 @@ namespace Tensorflow.Gradients | |||
| var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); | |||
| Tensor image_shape = null; | |||
| if (shape.is_fully_defined()) | |||
| image_shape = constant_op.constant(image.shape[1..3]); | |||
| image_shape = constant_op.constant(image.shape.Skip(1).Take(2).ToArray()); | |||
| else | |||
| image_shape = array_ops.shape(image)["1:3"]; | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class PreprocessingLayerArgs : LayerArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,15 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class TextVectorizationArgs : PreprocessingLayerArgs | |||
| { | |||
| public Func<Tensor, Tensor> Standardize { get; set; } | |||
| public string Split { get; set; } = "standardize"; | |||
| public int MaxTokens { get; set; } = -1; | |||
| public string OutputMode { get; set; } = "int"; | |||
| public int OutputSequenceLength { get; set; } = -1; | |||
| } | |||
| } | |||
| @@ -20,6 +20,54 @@ namespace Tensorflow | |||
| { | |||
| public class string_ops | |||
| { | |||
| public Tensor lower(Tensor input, string encoding = "", string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "StringLower", name, | |||
| null, | |||
| input, encoding); | |||
| return results[0]; | |||
| } | |||
| var _op = tf.OpDefLib._apply_op_helper("StringLower", name: name, args: new | |||
| { | |||
| input, | |||
| encoding | |||
| }); | |||
| return _op.output; | |||
| } | |||
| public Tensor regex_replace(Tensor input, string pattern, string rewrite, | |||
| bool replace_global = true, string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "StaticRegexReplace", name, | |||
| null, | |||
| input, | |||
| "pattern", pattern, | |||
| "rewrite", rewrite, | |||
| "replace_global", replace_global); | |||
| return results[0]; | |||
| } | |||
| var _op = tf.OpDefLib._apply_op_helper("StaticRegexReplace", name: name, args: new | |||
| { | |||
| input, | |||
| pattern, | |||
| rewrite, | |||
| replace_global | |||
| }); | |||
| return _op.output; | |||
| } | |||
| /// <summary> | |||
| /// Return substrings from `Tensor` of strings. | |||
| /// </summary> | |||
| @@ -0,0 +1,18 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| public class CombinerPreprocessingLayer : Layer | |||
| { | |||
| PreprocessingLayerArgs args; | |||
| public CombinerPreprocessingLayer(PreprocessingLayerArgs args) | |||
| : base(args) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -39,7 +39,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
| dataset = slice_inputs(indices_dataset, inputs); | |||
| } | |||
| Tensor permutation(Tensor tensor) | |||
| Tensors permutation(Tensors tensor) | |||
| { | |||
| var indices = math_ops.range(num_samples, dtype: dtypes.int64); | |||
| if (args.Shuffle) | |||
| @@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
| .Select(x => gen_array_ops.gather_v2(x, indices, 0)) | |||
| .ToArray(); | |||
| return new Tensors(results); | |||
| }); | |||
| }, -1); | |||
| return dataset.with_options(new DatasetOptions { }); | |||
| } | |||
| @@ -62,8 +62,8 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| var y_t_rank = y_t.rank; | |||
| var y_p_rank = y_p.rank; | |||
| var y_t_last_dim = y_t.shape[^1]; | |||
| var y_p_last_dim = y_p.shape[^1]; | |||
| var y_t_last_dim = y_t.shape[y_t.shape.Length - 1]; | |||
| var y_p_last_dim = y_p.shape[y_p.shape.Length - 1]; | |||
| bool is_binary = y_p_last_dim == 1; | |||
| bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1; | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Linq; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Layers; | |||
| @@ -103,7 +104,7 @@ namespace Tensorflow.Keras.Engine | |||
| if (set_inputs) | |||
| { | |||
| // If an input layer (placeholder) is available. | |||
| outputs = layer.InboundNodes[^1].Outputs; | |||
| outputs = layer.InboundNodes.Last().Outputs; | |||
| inputs = layer_utils.get_source_inputs(outputs[0]); | |||
| built = true; | |||
| _has_explicit_input_shape = true; | |||
| @@ -11,6 +11,7 @@ using Tensorflow.Keras.Metrics; | |||
| using Tensorflow.Keras.Models; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Keras.Utils; | |||
| namespace Tensorflow.Keras | |||
| { | |||
| @@ -27,6 +28,7 @@ namespace Tensorflow.Keras | |||
| public OptimizerApi optimizers { get; } = new OptimizerApi(); | |||
| public MetricsApi metrics { get; } = new MetricsApi(); | |||
| public ModelsApi models { get; } = new ModelsApi(); | |||
| public KerasUtils utils { get; } = new KerasUtils(); | |||
| public Sequential Sequential(List<ILayer> layers = null, | |||
| string name = null) | |||
| @@ -73,7 +75,7 @@ namespace Tensorflow.Keras | |||
| Tensor tensor = null) | |||
| { | |||
| if (batch_input_shape != null) | |||
| shape = batch_input_shape.dims[1..]; | |||
| shape = batch_input_shape.dims.Skip(1).ToArray(); | |||
| var args = new InputLayerArgs | |||
| { | |||
| @@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Layers | |||
| if (BatchInputShape != null) | |||
| { | |||
| args.BatchSize = BatchInputShape.dims[0]; | |||
| args.InputShape = BatchInputShape.dims[1..]; | |||
| args.InputShape = BatchInputShape.dims.Skip(1).ToArray(); | |||
| } | |||
| // moved to base class | |||
| @@ -9,6 +9,8 @@ namespace Tensorflow.Keras.Layers | |||
| { | |||
| public partial class LayersApi | |||
| { | |||
| public Preprocessing preprocessing { get; } = new Preprocessing(); | |||
| /// <summary> | |||
| /// Functional interface for the batch normalization layer. | |||
| /// http://arxiv.org/abs/1502.03167 | |||
| @@ -0,0 +1,20 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public class TextVectorization : CombinerPreprocessingLayer | |||
| { | |||
| TextVectorizationArgs args; | |||
| public TextVectorization(TextVectorizationArgs args) | |||
| : base(args) | |||
| { | |||
| args.DType = TF_DataType.TF_STRING; | |||
| // string standardize = "lower_and_strip_punctuation", | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| @@ -45,7 +46,7 @@ namespace Tensorflow.Keras.Layers | |||
| return array_ops.reshape(inputs, new[] { batch_dim, -1 }); | |||
| } | |||
| var non_batch_dims = ((int[])input_shape)[1..]; | |||
| var non_batch_dims = ((int[])input_shape).Skip(1).ToArray(); | |||
| var num = 1; | |||
| if (non_batch_dims.Length > 0) | |||
| { | |||
| @@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Layers | |||
| public override TensorShape ComputeOutputShape(TensorShape input_shape) | |||
| { | |||
| if (input_shape.dims[1..].Contains(-1)) | |||
| if (input_shape.dims.Skip(1).Contains(-1)) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using System.Linq; | |||
| namespace Tensorflow.Keras.Preprocessings | |||
| { | |||
| @@ -17,18 +18,21 @@ namespace Tensorflow.Keras.Preprocessings | |||
| float validation_split, | |||
| string subset) | |||
| { | |||
| if (string.IsNullOrEmpty(subset)) | |||
| return (samples, labels); | |||
| var num_val_samples = Convert.ToInt32(samples.Length * validation_split); | |||
| if (subset == "training") | |||
| { | |||
| Console.WriteLine($"Using {samples.Length - num_val_samples} files for training."); | |||
| samples = samples[..^num_val_samples]; | |||
| labels = labels[..^num_val_samples]; | |||
| samples = samples.Take(samples.Length - num_val_samples).ToArray(); | |||
| labels = labels.Take(labels.Length - num_val_samples).ToArray(); | |||
| } | |||
| else if (subset == "validation") | |||
| { | |||
| Console.WriteLine($"Using {num_val_samples} files for validation."); | |||
| samples = samples[(samples.Length - num_val_samples)..]; | |||
| labels = labels[(labels.Length - num_val_samples)..]; | |||
| samples = samples.Skip(samples.Length - num_val_samples).ToArray(); | |||
| labels = labels.Skip(labels.Length - num_val_samples).ToArray(); | |||
| } | |||
| else | |||
| throw new NotImplementedException(""); | |||
| @@ -21,40 +21,41 @@ namespace Tensorflow.Keras.Preprocessings | |||
| /// file_paths, labels, class_names | |||
| /// </returns> | |||
| public (string[], int[], string[]) index_directory(string directory, | |||
| string labels, | |||
| string[] formats = null, | |||
| string[] class_names = null, | |||
| bool shuffle = true, | |||
| int? seed = null, | |||
| bool follow_links = false) | |||
| { | |||
| var labels = new List<int>(); | |||
| var label_list = new List<int>(); | |||
| var file_paths = new List<string>(); | |||
| var class_dirs = Directory.GetDirectories(directory); | |||
| class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar)[^1]).ToArray(); | |||
| class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar).Last()).ToArray(); | |||
| for (var label = 0; label < class_dirs.Length; label++) | |||
| { | |||
| var files = Directory.GetFiles(class_dirs[label]); | |||
| file_paths.AddRange(files); | |||
| labels.AddRange(Enumerable.Range(0, files.Length).Select(x => label)); | |||
| label_list.AddRange(Enumerable.Range(0, files.Length).Select(x => label)); | |||
| } | |||
| var return_labels = labels.Select(x => x).ToArray(); | |||
| var return_labels = label_list.Select(x => x).ToArray(); | |||
| var return_file_paths = file_paths.Select(x => x).ToArray(); | |||
| if (shuffle) | |||
| { | |||
| if (!seed.HasValue) | |||
| seed = np.random.randint((long)1e6); | |||
| var random_index = np.arange(labels.Count); | |||
| var random_index = np.arange(label_list.Count); | |||
| var rng = np.random.RandomState(seed.Value); | |||
| rng.shuffle(random_index); | |||
| var index = random_index.ToArray<int>(); | |||
| for (int i = 0; i < labels.Count; i++) | |||
| for (int i = 0; i < label_list.Count; i++) | |||
| { | |||
| return_labels[i] = labels[index[i]]; | |||
| return_labels[i] = label_list[index[i]]; | |||
| return_file_paths[i] = file_paths[index[i]]; | |||
| } | |||
| } | |||
| @@ -1,4 +1,7 @@ | |||
| using Tensorflow.Keras.Preprocessings; | |||
| using System; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Preprocessings; | |||
| namespace Tensorflow.Keras | |||
| { | |||
| @@ -6,5 +9,18 @@ namespace Tensorflow.Keras | |||
| { | |||
| public Sequence sequence => new Sequence(); | |||
| public DatasetUtils dataset_utils => new DatasetUtils(); | |||
| public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null, | |||
| string split = "standardize", | |||
| int max_tokens = -1, | |||
| string output_mode = "int", | |||
| int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs | |||
| { | |||
| Standardize = standardize, | |||
| Split = split, | |||
| MaxTokens = max_tokens, | |||
| OutputMode = output_mode, | |||
| OutputSequenceLength = output_sequence_length | |||
| }); | |||
| } | |||
| } | |||
| @@ -43,6 +43,7 @@ namespace Tensorflow.Keras | |||
| num_channels = 3; | |||
| var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory, | |||
| labels, | |||
| formats: WHITELIST_FORMATS, | |||
| class_names: class_names, | |||
| shuffle: shuffle, | |||
| @@ -64,13 +65,30 @@ namespace Tensorflow.Keras | |||
| string[] class_names = null, | |||
| int batch_size = 32, | |||
| bool shuffle = true, | |||
| int max_length = -1, | |||
| int? seed = null, | |||
| float validation_split = 0.2f, | |||
| string subset = null) | |||
| string subset = null, | |||
| bool follow_links = false) | |||
| { | |||
| var (file_paths, label_list, class_name_list) = dataset_utils.index_directory( | |||
| directory, | |||
| labels, | |||
| formats: new[] { ".txt" }, | |||
| class_names: class_names, | |||
| shuffle: shuffle, | |||
| seed: seed, | |||
| follow_links: follow_links); | |||
| return null; | |||
| (file_paths, label_list) = dataset_utils.get_training_or_validation_split( | |||
| file_paths, label_list, validation_split, subset); | |||
| var dataset = paths_and_labels_to_dataset(file_paths, label_list, label_mode, class_name_list.Length); | |||
| if (shuffle) | |||
| dataset = dataset.shuffle(batch_size * 8, seed: seed); | |||
| dataset = dataset.batch(batch_size); | |||
| dataset.class_names = class_name_list; | |||
| return dataset; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using System.IO; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras | |||
| @@ -34,5 +35,31 @@ namespace Tensorflow.Keras | |||
| // img.set_shape((image_size[0], image_size[1], num_channels)); | |||
| return img; | |||
| } | |||
| public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths, | |||
| int[] labels, | |||
| string label_mode, | |||
| int num_classes, | |||
| int max_length = -1) | |||
| { | |||
| var path_ds = tf.data.Dataset.from_tensor_slices(image_paths); | |||
| var string_ds = path_ds.map(x => path_to_string_content(x, max_length)); | |||
| if (label_mode == "int") | |||
| { | |||
| var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes); | |||
| string_ds = tf.data.Dataset.zip(string_ds, label_ds); | |||
| } | |||
| return string_ds; | |||
| } | |||
| Tensor path_to_string_content(Tensor path, int max_length) | |||
| { | |||
| var txt = tf.io.read_file(path); | |||
| if (max_length > -1) | |||
| txt = tf.strings.substr(txt, 0, max_length); | |||
| return txt; | |||
| } | |||
| } | |||
| } | |||
| @@ -6,7 +6,7 @@ | |||
| <LangVersion>8.0</LangVersion> | |||
| <RootNamespace>Tensorflow.Keras</RootNamespace> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| <Version>0.4.0</Version> | |||
| <Version>0.4.1</Version> | |||
| <Authors>Haiping Chen</Authors> | |||
| <Product>Keras for .NET</Product> | |||
| <Copyright>Apache 2.0, Haiping Chen 2020</Copyright> | |||
| @@ -34,8 +34,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
| <RepositoryType>Git</RepositoryType> | |||
| <SignAssembly>true</SignAssembly> | |||
| <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
| <AssemblyVersion>0.4.0.0</AssemblyVersion> | |||
| <FileVersion>0.4.0.0</FileVersion> | |||
| <AssemblyVersion>0.4.1.0</AssemblyVersion> | |||
| <FileVersion>0.4.1.0</FileVersion> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| </PropertyGroup> | |||
| @@ -2,7 +2,7 @@ | |||
| <PropertyGroup> | |||
| <OutputType>Exe</OutputType> | |||
| <TargetFramework>netcoreapp3.1</TargetFramework> | |||
| <TargetFramework>net5.0</TargetFramework> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| </PropertyGroup> | |||
| @@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\ | |||
| 1. Build static library | |||
| `bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow` | |||
| `bazel build --output_base=C:/tmp/tfcompilation --config=opt //tensorflow:tensorflow` | |||
| 2. Build pip package | |||
| @@ -70,6 +70,16 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\ | |||
| `pip install C:/tmp/tensorflow_pkg/tensorflow-1.15.0-cp36-cp36m-win_amd64.whl` | |||
| ### Build from source for MacOS | |||
| ```shell | |||
| $ cd /usr/local/lib/bazel/bin | |||
| $ curl -LO https://release.bazel.build/3.7.2/release/bazel-3.7.2-darwin-x86_64 | |||
| $ chmod +x bazel-3.7.2-darwin-x86_64 | |||
| $ cd ~/Projects/tensorflow | |||
| $ bazel build --config=opt //tensorflow:tensorflow | |||
| ``` | |||
| ### Build specific version for tf.net | |||
| https://github.com/SciSharp/tensorflow | |||