| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow.Keras; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -7,6 +8,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| static void Main(string[] args) | static void Main(string[] args) | ||||
| { | { | ||||
| tf.UseKeras<KerasInterface>(); | |||||
| var diag = new Diagnostician(); | var diag = new Diagnostician(); | ||||
| // diag.Diagnose(@"D:\memory.txt"); | // diag.Diagnose(@"D:\memory.txt"); | ||||
| @@ -58,6 +58,12 @@ namespace Tensorflow | |||||
| NDArray l2_regularizer = null, bool fast = true, string name = null) | NDArray l2_regularizer = null, bool fast = true, string name = null) | ||||
| => ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name); | => ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name); | ||||
| public Tensors qr(Tensor input, bool full_matrices = true, string name = null) | |||||
| => ops.qr(input, full_matrices: full_matrices, name: name); | |||||
| public Tensor tensor_diag_part(Tensor input, string name = null) | |||||
| => gen_array_ops.diag_part(input, name: name); | |||||
| public Tensor tensordot(Tensor x, Tensor y, NDArray axes, string name = null) | public Tensor tensordot(Tensor x, Tensor y, NDArray axes, string name = null) | ||||
| => math_ops.tensordot(x, y, axes, name: name); | => math_ops.tensordot(x, y, axes, name: name); | ||||
| } | } | ||||
| @@ -39,6 +39,12 @@ namespace Tensorflow | |||||
| int? seed = null, | int? seed = null, | ||||
| string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | ||||
| public Tensor stateless_normal(Shape shape, | |||||
| float mean = 0.0f, | |||||
| float stddev = 1.0f, | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| string name = null) => stateless_random_ops.stateless_random_normal(shape, mean, stddev, dtype, name: name); | |||||
| /// <summary> | /// <summary> | ||||
| /// Outputs random values from a truncated normal distribution. | /// Outputs random values from a truncated normal distribution. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -0,0 +1,11 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras | |||||
| { | |||||
| public interface IInitializersApi | |||||
| { | |||||
| IInitializer Orthogonal(float gain = 1.0f, int? seed = null); | |||||
| } | |||||
| } | |||||
| @@ -8,5 +8,6 @@ namespace Tensorflow.Keras | |||||
| public interface IKerasApi | public interface IKerasApi | ||||
| { | { | ||||
| public ILayersApi layers { get; } | public ILayersApi layers { get; } | ||||
| public IInitializersApi initializers { get; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -109,6 +109,7 @@ namespace Tensorflow.NumPy | |||||
| TF_DataType.TF_INT8 => Render(array.ToArray<sbyte>(), array.shape), | TF_DataType.TF_INT8 => Render(array.ToArray<sbyte>(), array.shape), | ||||
| TF_DataType.TF_INT32 => Render(array.ToArray<int>(), array.shape), | TF_DataType.TF_INT32 => Render(array.ToArray<int>(), array.shape), | ||||
| TF_DataType.TF_INT64 => Render(array.ToArray<long>(), array.shape), | TF_DataType.TF_INT64 => Render(array.ToArray<long>(), array.shape), | ||||
| TF_DataType.TF_UINT64 => Render(array.ToArray<ulong>(), array.shape), | |||||
| TF_DataType.TF_FLOAT => Render(array.ToArray<float>(), array.shape), | TF_DataType.TF_FLOAT => Render(array.ToArray<float>(), array.shape), | ||||
| TF_DataType.TF_DOUBLE => Render(array.ToArray<double>(), array.shape), | TF_DataType.TF_DOUBLE => Render(array.ToArray<double>(), array.shape), | ||||
| _ => Render(array.ToArray<byte>(), array.shape) | _ => Render(array.ToArray<byte>(), array.shape) | ||||
| @@ -1,32 +1,62 @@ | |||||
| using System; | |||||
| /***************************************************************************** | |||||
| Copyright 2023 Haiping Chen. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.TensorShapeProto.Types; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Operations.Initializers | |||||
| namespace Tensorflow.Operations.Initializers; | |||||
| public class Orthogonal : IInitializer | |||||
| { | { | ||||
| public class Orthogonal : IInitializer | |||||
| float _gain = 0f; | |||||
| int? _seed; | |||||
| public Orthogonal(float gain = 1.0f, int? seed = null) | |||||
| { | { | ||||
| float _gain = 0f; | |||||
| _gain = gain; | |||||
| _seed = seed; | |||||
| } | |||||
| public Orthogonal(float gain = 1.0f, int? seed = null) | |||||
| { | |||||
| public Tensor Apply(InitializerArgs args) | |||||
| { | |||||
| return _generate_init_val(args.Shape, args.DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : args.DType); | |||||
| } | |||||
| } | |||||
| private Tensor _generate_init_val(Shape shape, TF_DataType dtype) | |||||
| { | |||||
| var num_rows = 1L; | |||||
| foreach (var dim in shape.dims.Take(shape.ndim - 1)) | |||||
| num_rows *= dim; | |||||
| var num_cols = shape.dims.Last(); | |||||
| var flat_shape = (Math.Max(num_cols, num_rows), Math.Min(num_cols, num_rows)); | |||||
| public Tensor Apply(InitializerArgs args) | |||||
| { | |||||
| return _generate_init_val(args.Shape, args.DType); | |||||
| } | |||||
| var a = tf.random.stateless_normal(flat_shape, dtype: dtype); | |||||
| // Compute the qr factorization | |||||
| var (q, r) = tf.linalg.qr(a, full_matrices: false); | |||||
| // Make Q uniform | |||||
| var d = tf.linalg.tensor_diag_part(r); | |||||
| q *= tf.sign(d); | |||||
| private Tensor _generate_init_val(Shape shape, TF_DataType dtype) | |||||
| if (num_rows < num_cols) | |||||
| { | { | ||||
| var num_rows = 1L; | |||||
| foreach (var dim in shape.dims.Take(shape.ndim - 1)) | |||||
| num_rows *= dim; | |||||
| var num_cols = shape.dims.Last(); | |||||
| var flat_shape = (Math.Max(num_cols, num_rows), Math.Min(num_cols, num_rows)); | |||||
| // q = tf.linalg.matrix_transpose(q); | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| return _gain * tf.reshape(q, shape); | |||||
| } | } | ||||
| } | } | ||||
| @@ -113,6 +113,9 @@ namespace Tensorflow | |||||
| public static Tensor diag(Tensor diagonal, string name = null) | public static Tensor diag(Tensor diagonal, string name = null) | ||||
| => tf.Context.ExecuteOp("Diag", name, new ExecuteOpArgs(diagonal)); | => tf.Context.ExecuteOp("Diag", name, new ExecuteOpArgs(diagonal)); | ||||
| public static Tensor diag_part(Tensor diagonal, string name = null) | |||||
| => tf.Context.ExecuteOp("DiagPart", name, new ExecuteOpArgs(diagonal)); | |||||
| public static Tensor expand_dims(Tensor input, int axis, string name = null) | public static Tensor expand_dims(Tensor input, int axis, string name = null) | ||||
| => tf.Context.ExecuteOp("ExpandDims", name, new ExecuteOpArgs(input, axis) | => tf.Context.ExecuteOp("ExpandDims", name, new ExecuteOpArgs(input, axis) | ||||
| .SetAttributes(new { dim = axis })); | .SetAttributes(new { dim = axis })); | ||||
| @@ -13,7 +13,10 @@ | |||||
| See the License for the specific language governing permissions and | See the License for the specific language governing permissions and | ||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using static Tensorflow.ApiDef.Types; | |||||
| using System.Reflection; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using System.Xml.Linq; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -85,6 +88,15 @@ namespace Tensorflow | |||||
| int? seed2 = 0, string name = null) | int? seed2 = 0, string name = null) | ||||
| => tf.Context.ExecuteOp("TruncatedNormal", name, new ExecuteOpArgs(shape) | => tf.Context.ExecuteOp("TruncatedNormal", name, new ExecuteOpArgs(shape) | ||||
| .SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 })); | .SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 })); | ||||
| public static Tensor stateless_random_normal_v2(Tensor shape, Tensor key, Tensor counter, | |||||
| int alg, TF_DataType dtype, string name = null) | |||||
| => tf.Context.ExecuteOp("StatelessRandomNormalV2", name, | |||||
| new ExecuteOpArgs(shape, key, counter, alg) | |||||
| .SetAttributes(new { dtype })); | |||||
| public static Tensors stateless_random_get_key_counter(int[] seed, string name = null) | |||||
| => tf.Context.ExecuteOp("StatelessRandomGetKeyCounter", name, | |||||
| new ExecuteOpArgs(seed)); | |||||
| public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0, | public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0, | ||||
| int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null) | int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null) | ||||
| @@ -129,5 +129,12 @@ namespace Tensorflow | |||||
| lower, | lower, | ||||
| adjoint | adjoint | ||||
| })); | })); | ||||
| public Tensors qr(Tensor input, bool full_matrices = false, string name = null) | |||||
| => tf.Context.ExecuteOp("Qr", name, | |||||
| new ExecuteOpArgs(input).SetAttributes(new | |||||
| { | |||||
| full_matrices | |||||
| })); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,62 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2023 Haiping Chen. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using static Tensorflow.ApiDef.Types; | |||||
| using System.Reflection; | |||||
| using static Tensorflow.Binding; | |||||
| using System; | |||||
| namespace Tensorflow; | |||||
| public class stateless_random_ops | |||||
| { | |||||
| public static Tensor stateless_random_normal(Shape shape, | |||||
| float mean = 0.0f, | |||||
| float stddev = 1.0f, | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| int[]? seed = null, | |||||
| string name = null) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, "stateless_random_normal", new { shape, seed, mean, stddev }), scope => | |||||
| { | |||||
| name = scope; | |||||
| var shape_tensor = _ShapeTensor(shape); | |||||
| var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); | |||||
| var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev"); | |||||
| if (seed == null) | |||||
| { | |||||
| seed = new[] { new Random().Next(), 0 }; | |||||
| } | |||||
| var (key, counter) = _get_key_counter(seed, 3); | |||||
| var rnd = gen_random_ops.stateless_random_normal_v2(shape: shape_tensor, key: key, counter: counter, dtype: dtype, alg: 3); | |||||
| var value = math_ops.add(rnd * stddev, mean_tensor, name: name); | |||||
| // tensor_util.maybe_set_static_shape(value, shape) | |||||
| return value; | |||||
| }); | |||||
| } | |||||
| private static Tensor _ShapeTensor(int[] shape) | |||||
| { | |||||
| return ops.convert_to_tensor(shape, name: "shape"); | |||||
| } | |||||
| private static (Tensor, Tensor) _get_key_counter(int[] seed, int alg) | |||||
| { | |||||
| var results = gen_random_ops.stateless_random_get_key_counter(seed); | |||||
| return (results[0], results[1]); | |||||
| } | |||||
| } | |||||
| @@ -67,7 +67,10 @@ namespace Tensorflow | |||||
| public void UseKeras<T>() where T : IKerasApi, new() | public void UseKeras<T>() where T : IKerasApi, new() | ||||
| { | { | ||||
| keras = new T(); | |||||
| if (keras == null) | |||||
| { | |||||
| keras = new T(); | |||||
| } | |||||
| } | } | ||||
| public string VERSION => c_api.StringPiece(c_api.TF_Version()); | public string VERSION => c_api.StringPiece(c_api.TF_Version()); | ||||
| @@ -16,18 +16,20 @@ | |||||
| using Tensorflow.Operations.Initializers; | using Tensorflow.Operations.Initializers; | ||||
| namespace Tensorflow.Keras | |||||
| namespace Tensorflow.Keras; | |||||
| public partial class InitializersApi : IInitializersApi | |||||
| { | { | ||||
| public class Initializers | |||||
| /// <summary> | |||||
| /// He normal initializer. | |||||
| /// </summary> | |||||
| /// <param name="seed"></param> | |||||
| /// <returns></returns> | |||||
| public IInitializer he_normal(int? seed = null) | |||||
| { | { | ||||
| /// <summary> | |||||
| /// He normal initializer. | |||||
| /// </summary> | |||||
| /// <param name="seed"></param> | |||||
| /// <returns></returns> | |||||
| public IInitializer he_normal(int? seed = null) | |||||
| { | |||||
| return new VarianceScaling(factor: 2.0f, mode: "fan_in", seed: seed); | |||||
| } | |||||
| return new VarianceScaling(factor: 2.0f, mode: "fan_in", seed: seed); | |||||
| } | } | ||||
| public IInitializer Orthogonal(float gain = 1.0f, int? seed = null) | |||||
| => new Orthogonal(gain: gain, seed: seed); | |||||
| } | } | ||||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Keras | |||||
| public class KerasInterface : IKerasApi | public class KerasInterface : IKerasApi | ||||
| { | { | ||||
| public KerasDataset datasets { get; } = new KerasDataset(); | public KerasDataset datasets { get; } = new KerasDataset(); | ||||
| public Initializers initializers { get; } = new Initializers(); | |||||
| public IInitializersApi initializers { get; } = new InitializersApi(); | |||||
| public Regularizers regularizers { get; } = new Regularizers(); | public Regularizers regularizers { get; } = new Regularizers(); | ||||
| public ILayersApi layers { get; } = new LayersApi(); | public ILayersApi layers { get; } = new LayersApi(); | ||||
| public LossesApi losses { get; } = new LossesApi(); | public LossesApi losses { get; } = new LossesApi(); | ||||
| @@ -1,5 +1,6 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow.Keras; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.Keras.UnitTest | namespace TensorFlowNET.Keras.UnitTest | ||||
| @@ -9,6 +10,8 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
| [TestInitialize] | [TestInitialize] | ||||
| public void TestInit() | public void TestInit() | ||||
| { | { | ||||
| tf.UseKeras<KerasInterface>(); | |||||
| if (!tf.executing_eagerly()) | if (!tf.executing_eagerly()) | ||||
| tf.enable_eager_execution(); | tf.enable_eager_execution(); | ||||
| tf.Context.ensure_initialized(); | tf.Context.ensure_initialized(); | ||||
| @@ -0,0 +1,20 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using TensorFlowNET.Keras.UnitTest; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.UnitTest; | |||||
| [TestClass] | |||||
| public class InitializerTest : EagerModeTestBase | |||||
| { | |||||
| [TestMethod] | |||||
| public void Orthogonal() | |||||
| { | |||||
| var initializer = tf.keras.initializers.Orthogonal(); | |||||
| var values = initializer.Apply(new InitializerArgs((2, 2))); | |||||
| } | |||||
| } | |||||