@@ -17,7 +17,9 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -76,7 +78,14 @@ namespace Tensorflow
        public Tensor concat(IList<Tensor> values, int axis, string name = "concat")
        {
            if (values.Count == 1)
                throw new NotImplementedException("tf.concat length is 1");
            {
                return tf_with(ops.name_scope(name), scope =>
                {
                    var tensor = ops.convert_to_tensor(axis, name: "concat_dim", dtype: dtypes.int32);
                    Debug.Assert(tensor.TensorShape.ndim == 0);
                    return identity(values[0], name: scope);
                });
            }

            return gen_array_ops.concat_v2(values.ToArray(), axis, name: name);
        }
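Note on the single-tensor path above: as in the Python tf.concat, the axis is still materialized as a scalar "concat_dim" tensor (the Debug.Assert checks it is rank 0) before the lone value is passed through identity. A minimal usage sketch, assuming tf.constant accepts the NumSharp arrays shown (values are illustrative):

    var a = tf.constant(np.array(new[,] { { 1, 2 }, { 3, 4 } }));
    var b = tf.constant(np.array(new[,] { { 5, 6 } }));
    var one = tf.concat(new List<Tensor> { a }, axis: 0);     // now returns identity(a) instead of throwing
    var two = tf.concat(new List<Tensor> { a, b }, axis: 0);  // unchanged path: lowers to ConcatV2, shape (3, 2)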
@@ -111,7 +120,7 @@ namespace Tensorflow
        /// <param name="input"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public static Tensor identity(Tensor input, string name = null)
        public Tensor identity(Tensor input, string name = null)
            => array_ops.identity(input, name: name);

        /// <summary>
@@ -150,10 +159,10 @@ namespace Tensorflow
        /// <param name="axis"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public static Tensor reverse(Tensor tensor, int[] axis, string name = null)
        public Tensor reverse(Tensor tensor, int[] axis, string name = null)
            => gen_array_ops.reverse(tensor, axis, name: name);

        public static Tensor reverse(Tensor tensor, Tensor axis, string name = null)
        public Tensor reverse(Tensor tensor, Tensor axis, string name = null)
            => gen_array_ops.reverse(tensor, axis, name: name);

        /// <summary>
@@ -277,5 +286,14 @@ namespace Tensorflow
        /// <returns>A `Tensor` with all elements set to zero.</returns>
        public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
            => array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize);
        /// <summary>
        /// Stops gradient computation.
        /// </summary>
        /// <param name="x"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public Tensor stop_gradient(Tensor x, string name = null)
            => gen_array_ops.stop_gradient(x, name: name);
    }
}
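The new stop_gradient wrapper simply forwards to the StopGradient op; during backprop its input is treated as a constant. A small sketch, assuming tf.constant accepts a float[] here (values are illustrative):

    var logits = tf.constant(new float[] { 0.1f, 0.9f });
    var frozen = tf.stop_gradient(logits);   // no gradient flows back through this edge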
@@ -434,11 +434,14 @@ namespace Tensorflow
        public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null,
            bool keepdims = false, string name = null)
        {
            if(!axis.HasValue && reduction_indices.HasValue)
            if (!axis.HasValue && reduction_indices.HasValue && !keepdims)
                return math_ops.reduce_sum(input, reduction_indices.Value);
            else if (axis.HasValue && !reduction_indices.HasValue)
            else if (axis.HasValue && !reduction_indices.HasValue && !keepdims)
                return math_ops.reduce_sum(input, axis.Value);
            return math_ops.reduce_sum(input, keepdims: keepdims, name: name);
            else if (axis.HasValue && !reduction_indices.HasValue && keepdims)
                return math_ops.reduce_sum(input, keepdims: keepdims, axis: axis.Value, name: name);
            else
                return math_ops.reduce_sum(input, keepdims: keepdims, name: name);
        }
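The extra branches above mean keepdims is no longer silently dropped when axis is also given. Expected behaviour, as a sketch (the np.array construction is illustrative):

    var x = tf.constant(np.array(new float[,] { { 1, 2 }, { 3, 4 } }));
    var total   = tf.reduce_sum(x);                            // scalar: 10
    var per_col = tf.reduce_sum(x, axis: 0);                   // shape (2): [4, 6]
    var kept    = tf.reduce_sum(x, axis: 0, keepdims: true);   // shape (1, 2) -- the newly handled case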
        public Tensor reduce_sum(Tensor input, TensorShape axis, int? reduction_indices = null,
@@ -471,6 +474,9 @@ namespace Tensorflow
        public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
            => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);

        public Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null)
            => math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name);

        public Tensor round(Tensor x, string name = null)
            => gen_math_ops.round(x, name: name);
@@ -65,5 +65,10 @@ namespace Tensorflow
        public void set_random_seed(int seed)
            => ops.get_default_graph().seed = seed;

        public Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
            string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)
            => random_ops.multinomial(logits, num_samples, seed: seed,
                name: name, output_dtype: output_dtype);
    }
}
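Usage sketch for the new multinomial binding, assuming it is exposed on the tf instance like set_random_seed above: each row of logits is an unnormalized log-probability distribution and the result holds num_samples drawn class indices per row (int64 by default). Construction of the logits is illustrative:

    var logits = tf.constant(np.array(new float[,] { { 0f, 1f, 2f } }));   // shape (1, 3)
    var samples = tf.multinomial(logits, num_samples: 5, seed: 42);        // shape (1, 5), dtype int64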
@@ -15,6 +15,7 @@
 ******************************************************************************/

using System.Collections.Generic;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Train;

namespace Tensorflow
@@ -73,6 +74,26 @@ namespace Tensorflow
            public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null)
                => checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename);

            public Tensor polynomial_decay(float learning_rate,
                RefVariable global_step,
                float decay_steps,
                float end_learning_rate = 0.0001f,
                float power = 1.0f,
                bool cycle = false,
                string name = null)
            {
                var decayed = new PolynomialDecay(learning_rate,
                    decay_steps,
                    end_learning_rate: end_learning_rate,
                    power: power,
                    cycle: cycle,
                    name: name);

                var decayed_lr = decayed.__call__(global_step);

                return decayed_lr;
            }
        }
    }
}
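Intended call shape for the new tf.train.polynomial_decay (a sketch only: the construction of global_step is illustrative, and PolynomialDecay.__call__ further down in this diff is still a stub, so this will not yet run end to end):

    var global_step = tf.Variable(0, trainable: false, name: "global_step");
    var lr = tf.train.polynomial_decay(0.1f, global_step,
        decay_steps: 10000, end_learning_rate: 0.01f, power: 0.5f);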
@@ -27,6 +27,15 @@ namespace Tensorflow
                .ToArray();
        }

        /// <summary>
        /// Returns an Op that initializes a list of variables.
        /// </summary>
        /// <param name="var_list">List of `Variable` objects to initialize.</param>
        /// <param name="name">Optional name for the returned operation.</param>
        /// <returns>An Op that runs the initializers of all the specified variables.</returns>
        public Operation variables_initializer(VariableV1[] var_list, string name = "init")
            => variables.variables_initializer(var_list, name: name);

        public Operation global_variables_initializer()
        {
            var g = variables.global_variables();
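Usage sketch for the targeted initializer added above (assumes RefVariable derives from VariableV1 in this code base; names are illustrative):

    var a = tf.Variable(1, name: "a");
    var b = tf.Variable(2, name: "b");
    var init_ab = tf.variables_initializer(new VariableV1[] { a, b }, name: "init_ab");
    using (var sess = tf.Session())
        sess.run(init_ab);   // initializes only a and b, not every global variable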
@@ -115,6 +115,7 @@ namespace Tensorflow
            return instance;
        }

        [DebuggerStepThrough]
        [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception
        public static void tf_with(IObjectLife py, Action<IObjectLife> action)
        {
@@ -273,7 +274,10 @@ namespace Tensorflow
            return sum;
        }

        public static double sum(IEnumerable<int> enumerable)
        public static float sum(IEnumerable<float> enumerable)
            => enumerable.Sum();

        public static int sum(IEnumerable<int> enumerable)
            => enumerable.Sum();

        public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values)
@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.Keras.Optimizers
{
    public class LearningRateSchedule
    {
        public LearningRateSchedule()
        {

        }
    }
}
@@ -0,0 +1,62 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Optimizers
{
    /// <summary>
    /// A LearningRateSchedule that uses a polynomial decay schedule.
    /// </summary>
    public class PolynomialDecay : LearningRateSchedule
    {
        float initial_learning_rate;
        float decay_steps;
        float end_learning_rate;
        float power;
        bool cycle;
        string name;

        public PolynomialDecay(float initial_learning_rate,
            float decay_steps,
            float end_learning_rate = 0.0001f,
            float power = 1.0f,
            bool cycle = false,
            string name = null) : base()
        {
            this.initial_learning_rate = initial_learning_rate;
            this.decay_steps = decay_steps;
            this.end_learning_rate = end_learning_rate;
            this.power = power;
            this.cycle = cycle;
            this.name = name;
        }

        public Tensor __call__(RefVariable step)
        {
            tf_with(ops.name_scope(name ?? "PolynomialDecay"), scope =>
            {
                name = scope;
                var initial_learning_rate_tensor = ops.convert_to_tensor(initial_learning_rate, name: "initial_learning_rate");
                var dtype = initial_learning_rate_tensor.dtype;
                var end_learning_rate_tensor = math_ops.cast(end_learning_rate, dtype);
                var power_tensor = math_ops.cast(power, dtype);
                var global_step_recomp = math_ops.cast(step, dtype);
                var decay_steps_recomp = math_ops.cast(decay_steps, dtype);

                if(cycle)
                {
                    throw new NotImplementedException("PolynomialDecay cycle");
                }
                else
                {

                }
            });

            throw new NotImplementedException("");
        }
    }
}
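For reference, the empty non-cycling branch above would, per the standard polynomial-decay formula, cap the step at decay_steps and interpolate between the two rates. A plain-float sketch of that math (hypothetical helper, not part of the diff):

    static float PolynomialDecayValue(float initialLr, float step, float decaySteps,
        float endLr = 0.0001f, float power = 1.0f)
    {
        var clipped = Math.Min(step, decaySteps);        // non-cycling: step is capped at decay_steps
        var fraction = 1.0f - clipped / decaySteps;      // 1 at step 0, 0 once decay_steps is reached
        return (initialLr - endLr) * (float)Math.Pow(fraction, power) + endLr;
    }

    // e.g. PolynomialDecayValue(0.1f, 500, 1000, endLr: 0.01f) ≈ 0.055f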
@@ -19,8 +19,7 @@ namespace Tensorflow.Operations.Initializers
    public class GlorotUniform : VarianceScaling
    {
        public GlorotUniform(float scale = 1.0f,
            string mode = "fan_avg",
            string distribution = "uniform",
            string mode = "FAN_AVG",
            int? seed = null,
            TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale,
                mode: mode,
@@ -36,7 +35,6 @@ namespace Tensorflow.Operations.Initializers
        {
            scale = _scale,
            mode = _mode,
            distribution = _distribution,
            seed = _seed,
            dtype = _dtype
        };
@@ -30,6 +30,7 @@ namespace Tensorflow.Operations.Initializers
        protected string _distribution;
        protected int? _seed;
        protected TF_DataType _dtype;
        protected bool _uniform;

        public VarianceScaling(float factor = 2.0f,
            string mode = "FAN_IN",
@@ -49,31 +50,31 @@ namespace Tensorflow.Operations.Initializers
            _mode = mode;
            _seed = seed;
            _dtype = dtype;
            _uniform = uniform;
        }

        public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null)
        {
            float n = 0;
            var (fan_in, fan_out) = _compute_fans(shape);
            if (_mode == "fan_in")
                _scale /= Math.Max(1, fan_in);
            else if (_mode == "fan_out")
                _scale /= Math.Max(1, fan_out);
            else
                _scale /= Math.Max(1, (fan_in + fan_out) / 2);
            if (_mode == "FAN_IN")
                n = fan_in;
            else if (_mode == "FAN_OUT")
                n = fan_out;
            else if(_mode == "FAN_AVG")
                n = (fan_in + fan_out) / 2.0f;

            if (_distribution == "normal" || _distribution == "truncated_normal")
            {
                float stddev = (float)Math.Sqrt(_scale) / .87962566103423978f;
                return random_ops.truncated_normal(shape, mean: 0.0f, stddev: stddev, dtype: dtype, seed: _seed);
            }
            else if (_distribution == "untruncated_normal")
            if(_uniform)
            {
                throw new NotImplementedException("truncated_normal");
                var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n));
                return random_ops.random_uniform(shape, -limit, limit,
                    dtype, seed: _seed);
            }
            else
            {
                var limit = Math.Sqrt(3.0f * _scale);
                return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
                var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n));
                return random_ops.truncated_normal(shape, 0.0f, trunc_stddev, dtype,
                    seed: _seed);
            }
        }
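Sanity check of the bounds used in the rewritten call(): with factor = 1.0 and FAN_AVG the uniform limit reproduces the classic Glorot/Xavier bound sqrt(6 / (fan_in + fan_out)); the numbers below are illustrative.

    float factor = 1.0f, fan_in = 100f, fan_out = 50f;
    float n = (fan_in + fan_out) / 2.0f;                               // FAN_AVG
    var limit = Convert.ToSingle(Math.Sqrt(3.0f * factor / n));        // 0.2f == sqrt(6f / (fan_in + fan_out))
    var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * factor / n)); // ≈ 0.13f for the non-uniform branch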
@@ -106,6 +107,7 @@ namespace Tensorflow.Operations.Initializers
            mode = _mode,
            distribution = _distribution,
            seed = _seed,
            uniform = _uniform,
            dtype = _dtype
        };
    }
@@ -383,7 +383,7 @@ namespace Tensorflow
        {
            var _op = _op_def_lib._apply_op_helper("StopGradient", name, args: new { input = x, name });
            return _op.outputs[0];
            return _op.output;
        }

        public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides,
@@ -115,7 +115,7 @@ namespace Tensorflow
        {
            var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims });
            return _op.outputs[0];
            return _op.output;
        }

        public static Tensor prod<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null)
@@ -98,7 +98,8 @@ namespace Tensorflow
        /// <param name="seed2"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, string name = null)
        public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0,
            string name = null)
        {
            var _op = _op_def_lib._apply_op_helper("RandomShuffle",
                name: name,
@@ -116,7 +117,8 @@ namespace Tensorflow
        /// <param name="seed2"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null)
        public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0,
            int? seed2 = 0, string name = null)
        {
            if (!seed.HasValue)
                seed = 0;
@@ -127,7 +129,24 @@ namespace Tensorflow
                name: name,
                args: new { shape, dtype, seed, seed2 });

            return _op.outputs[0];
            return _op.output;
        }

        public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0,
            int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null)
        {
            if (!seed.HasValue)
                seed = 0;

            if (!seed2.HasValue)
                seed2 = 0;

            if (output_dtype == TF_DataType.DtInvalid)
                output_dtype = TF_DataType.TF_INT64;

            var _op = _op_def_lib._apply_op_helper("Multinomial",
                name: name,
                args: new { logits, num_samples, seed, seed2, output_dtype });

            return _op.output;
        }
    }
}
@@ -81,6 +81,21 @@ namespace Tensorflow
            });
        }

        public static Tensor cast(float x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
        {
            var base_type = dtype.as_base_dtype();
            return tf_with(ops.name_scope(name, "Cast", new { x }), scope =>
            {
                name = scope;
                var x_tensor = ops.convert_to_tensor(x, name: "x");
                if (x_tensor.dtype.as_base_dtype() != base_type)
                    x_tensor = gen_math_ops.cast(x_tensor, base_type, name: name);

                return x_tensor;
            });
        }

        public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null)
        {
            return tf_with(ops.name_scope(name, "Cumsum", new {x}), scope =>
@@ -204,6 +219,12 @@ namespace Tensorflow
            }
        }

        public static Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null)
        {
            var m = gen_math_ops.mean(input_tensors, axis, keepdims, name);
            return _may_reduce_to_scalar(keepdims, axis, m);
        }

        /// <summary>
        /// Computes the product of elements across dimensions of a tensor.
        /// </summary>
@@ -142,6 +142,35 @@ namespace Tensorflow
        {
            return ops.convert_to_tensor(shape, name: "shape");
        }

        public static Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
            string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)
        {
            return tf_with(ops.name_scope(name, "multinomial", new { logits }), delegate
            {
                return multinomial_categorical_impl(logits, num_samples, output_dtype, seed);
            });
        }

        /// <summary>
        /// Implementation for random.categorical (v1) and random.categorical (v2).
        /// </summary>
        /// <param name="logits"></param>
        /// <param name="num_samples"></param>
        /// <param name="dtype"></param>
        /// <param name="seed"></param>
        /// <returns></returns>
        private static Tensor multinomial_categorical_impl(Tensor logits, int num_samples, TF_DataType dtype = TF_DataType.DtInvalid,
            int? seed = null)
        {
            logits = ops.convert_to_tensor(logits, name: "logits");
            var (seed1, seed2) = random_seed.get_seed(seed);
            return gen_random_ops.multinomial(logits,
                num_samples,
                seed: seed1,
                seed2: seed2,
                output_dtype: dtype);
        }
    }
}
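The seed argument is split by random_seed.get_seed into a graph-level and an op-level seed before reaching the Multinomial kernel, so draws become reproducible when either tf.set_random_seed or the per-call seed is set. A short sketch (logits construction is illustrative):

    tf.set_random_seed(1234);                                            // graph-level seed
    var logits = tf.constant(np.array(new float[,] { { 1f, 1f, 1f } }));
    var draw = random_ops.multinomial(logits, 3, seed: 42);              // op-level seed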
@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.Training
{
    public class learning_rate_decay
    {
        /// <summary>
        /// Applies a polynomial decay to the learning rate.
        /// </summary>
        /// <param name="learning_rate"></param>
        /// <param name="global_step"></param>
        /// <param name="decay_steps"></param>
        /// <param name="end_learning_rate"></param>
        /// <param name="power"></param>
        /// <param name="cycle"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public static Tensor polynomial_decay(float learning_rate, RefVariable global_step, float decay_steps,
            float end_learning_rate = 0.0001f, float power = 1.0f, bool cycle = false,
            string name = null)
        {
            throw new NotImplementedException("");
        }
    }
}
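A plausible way to complete this stub, mirroring what tf.train.polynomial_decay earlier in this diff already does by delegating to the new Keras schedule (a sketch, not part of the change):

    var schedule = new Tensorflow.Keras.Optimizers.PolynomialDecay(learning_rate, decay_steps,
        end_learning_rate: end_learning_rate, power: power, cycle: cycle, name: name);
    return schedule.__call__(global_step);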