@@ -193,8 +193,7 @@ namespace Tensorflow
                 Name = name
             });
-            throw new NotImplementedException("");
-            //return layer.apply(inputs).Item1;
+            return layer.Apply(inputs);
         }

         /// <summary>

@@ -66,8 +66,8 @@ namespace Tensorflow
             Tensor keep = null;
             if (keep_prob != null)
                 keep = 1.0f - keep_prob;
-            return nn_ops.dropout_v2(x, rate: rate.Value, noise_shape: noise_shape, seed: seed, name: name);
+            var rate_tensor = rate.HasValue ? tf.constant(rate.Value) : keep;
+            return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name);
         }

         /// <summary>
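Not part of the diff: a minimal usage sketch of the reworked dropout path, assuming the `nn_ops.dropout_v2` overload changed further down (the one that now takes the rate as a `Tensor`) and illustrative values.

```csharp
// Sketch only: exercising dropout_v2 with a Tensor-valued rate.
// rate = 0.2 drops 20% of units (equivalent to the legacy keep_prob = 0.8).
var x = tf.constant(new float[] { 1f, 2f, 3f, 4f });
var rate = tf.constant(0.2f);
var y = nn_ops.dropout_v2(x, rate: rate, seed: 42);
```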
@@ -150,7 +150,7 @@ namespace Tensorflow
             var variables = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
                 scope: scope_to_prepend_to_names);
             var var_list = new Dictionary<string, IVariableV1>();
-            // variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);
+            variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);

             return (var_list, imported_return_elements);
         }

@@ -277,6 +277,11 @@ namespace Tensorflow
                     var proto = x_ref_var.to_proto(export_scope);
                     col_def.BytesList.Value.Add(proto.ToByteString());
                 }
+                else if(x is ResourceVariable x_res_var)
+                {
+                    var proto = x_res_var.to_proto(export_scope);
+                    col_def.BytesList.Value.Add(proto.ToByteString());
+                }
             }
             break;
         case List<RefVariable> collection_list:
@@ -31,8 +31,23 @@ namespace Tensorflow
         /// <param name="output_func_def"></param>
         /// <param name="status"></param>
         [DllImport(TensorFlowLibName)]
-        public static extern void TF_FunctionToFunctionDef(IntPtr func, IntPtr output_func_def, SafeStatusHandle status);
+        public static extern void TF_FunctionToFunctionDef(IntPtr func, SafeBufferHandle output_func_def, SafeStatusHandle status);
+
+        [DllImport(TensorFlowLibName)]
+        public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name,
+            bool append_hash_to_fn_name,
+            int num_opers, IntPtr[] opers,
+            int ninputs, TF_Output[] inputs,
+            int noutputs, TF_Output[] outputs,
+            IntPtr output_names,
+            IntPtr opts,
+            string description,
+            SafeStatusHandle status);
+
+        [DllImport(TensorFlowLibName)]
+        public static extern IntPtr TF_FunctionName(IntPtr func);
+
+        [DllImport(TensorFlowLibName)]
+        public static extern void TF_GraphCopyFunction(IntPtr g, IntPtr func, IntPtr grad, SafeStatusHandle status);
     }
 }
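Not part of the diff: a sketch of how the retyped `TF_FunctionToFunctionDef` import might be consumed, mirroring the `Buffer`/`Status` pattern that `get_attr_list` uses later in this PR. The `FunctionDef` proto class and a valid `TF_Function*` pointer are assumed to be available.

```csharp
// Sketch only: serialize a native TF_Function* into a FunctionDef proto.
FunctionDef ToFunctionDef(IntPtr func)
{
    using var buf = new Buffer();                                   // managed wrapper over TF_Buffer
    c_api.TF_FunctionToFunctionDef(func, buf.Handle, tf.Status.Handle);
    tf.Status.Check(true);                                          // throw if the native call failed
    return FunctionDef.Parser.ParseFrom(buf.DangerousMemoryBlock.Stream());
}
```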
@@ -327,8 +327,9 @@ namespace Tensorflow.Gradients
             var output_shape = op.outputs[0]._shape_tuple();

             Tensor result, factor_tensor;
-            if(input_shape != null &&
-                output_shape != null)
+            if(tf.executing_eagerly()
+                && input_shape != null
+                && output_shape != null)
             {
                 var input_size = np.prod(input_shape);
                 var output_size = np.prod(output_shape);

@@ -339,11 +340,7 @@ namespace Tensorflow.Gradients
             {
                 var input_shape_tensor = array_ops.shape(op.inputs[0]);
                 var output_shape_tensor = array_ops.shape(op.outputs[0]);
-                var factor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor));
-                throw new NotImplementedException("");
-#pragma warning disable CS0162 // Unreachable code detected
-                factor_tensor = null;
-#pragma warning restore CS0162 // Unreachable code detected
+                factor_tensor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor));
             }

             result = math_ops.truediv(sum_grad, math_ops.cast(factor_tensor, sum_grad.dtype));

@@ -128,10 +128,10 @@ namespace Tensorflow.Gradients
         [RegisterGradient("Conv2D")]
         public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
         {
-            var dilations = op.get_attr<int[]>("dilations");
-            var strides = op.get_attr<int[]>("strides");
+            var dilations = op.get_attr_list<int>("dilations");
+            var strides = op.get_attr_list<int>("strides");
             var padding = op.get_attr<string>("padding");
-            var explicit_paddings = op.get_attr<int[]>("explicit_paddings");
+            var explicit_paddings = op.get_attr_list<int>("explicit_paddings");
             var use_cudnn_on_gpu = op.get_attr<bool>("use_cudnn_on_gpu");
             var data_format = op.get_attr<string>("data_format");
             var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });
@@ -287,8 +287,8 @@ namespace Tensorflow.Gradients
                     op.inputs[0],
                     op.outputs[0],
                     grad,
-                    op.get_attr("ksize") as int[],
-                    op.get_attr("strides") as int[],
+                    op.get_attr_list<int>("ksize"),
+                    op.get_attr_list<int>("strides"),
                     padding: op.get_attr("padding").ToString(),
                     data_format: op.get_attr("data_format").ToString())
             };

@@ -293,12 +293,6 @@ namespace Tensorflow
             _create_op_helper(op, compute_device);

-            /*Console.Write($"create_op: {op_type} '{node_def.Name}'");
-            Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
-            Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
-            Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
-            Console.WriteLine();*/
             return op;
         }

@@ -139,7 +139,7 @@ namespace Tensorflow
         /// <param name="status">TF_Status*</param>
         [DllImport(TensorFlowLibName)]
         public static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status);

         /// <summary>
         /// Returns the number of dimensions of the Tensor referenced by `output`
         /// in `graph`.
@@ -0,0 +1,16 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+    public class TensorLikeDataAdapterArgs
+    {
+        public Tensor X { get; set; }
+        public Tensor Y { get; set; }
+        public int BatchSize { get; set; }
+        public int Steps { get; set; }
+        public int Epochs { get; set; }
+        public bool Shuffle { get; set; }
+    }
+}
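Not part of the diff: a sketch of how the new args class is expected to be filled in before being passed to `TensorLikeDataAdapter` (whose constructor change appears below); `x_train`/`y_train` and all values are hypothetical.

```csharp
// Sketch only: populating the new args object and handing it to the adapter.
var args = new TensorLikeDataAdapterArgs
{
    X = x_train,          // features tensor (hypothetical)
    Y = y_train,          // labels tensor (hypothetical)
    BatchSize = 32,
    Steps = 100,
    Epochs = 5,
    Shuffle = true
};
var adapter = new TensorLikeDataAdapter(args);
```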
@@ -27,7 +27,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters
         public DataHandler(DataHandlerArgs args)
         {
+            this.args = args;
+            var adapter_cls = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs { });
         }
     }
 }

@@ -1,6 +1,7 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
 using static Tensorflow.Binding;

 namespace Tensorflow.Keras.Engine.DataAdapters

@@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
     /// </summary>
     public class TensorLikeDataAdapter : IDataAdapter
     {
-        public TensorLikeDataAdapter()
+        public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args)
         {
             tf.data.Dataset.range(5);
         }
@@ -0,0 +1,21 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Engine
+{
+    public partial class Layer
+    {
+        Dictionary<Layer, object> trainable_state;
+
+        Dictionary<Layer, object> _get_trainable_state()
+        {
+            trainable_state = new Dictionary<Layer, object>();
+            throw new NotImplementedException("");
+        }
+
+        void _set_trainable_state(Dictionary<Layer, object> trainable_state)
+        {
+            throw new NotImplementedException("");
+        }
+    }
+}
@@ -1,6 +1,7 @@
+using NumSharp;
+using static Tensorflow.Binding;
 using System;
 using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine.DataAdapters;
 using Tensorflow.Keras.Losses;
 using Tensorflow.Keras.Optimizers;

@@ -21,6 +22,7 @@ namespace Tensorflow.Keras.Engine
 #pragma warning restore CS0108 // Member hides inherited member; missing new keyword
         string loss;
         IOptimizer optimizer;
+        IVariableV1 _steps_per_execution;

         public Model(ModelArgs args)
             : base(args)

@@ -37,10 +39,25 @@ namespace Tensorflow.Keras.Engine
                     break;
             }

+            int experimental_steps_per_execution = 1;
+            _configure_steps_per_execution(experimental_steps_per_execution);
+            _reset_compile_cache();
+
             loss = lossName;
             _is_compiled = true;
+        }
+
+        void _configure_steps_per_execution(int steps_per_execution)
+        {
+            _steps_per_execution = tf.Variable(steps_per_execution,
+                dtype: TF_DataType.TF_INT64,
+                aggregation: VariableAggregation.OnlyFirstReplica);
+        }
+
+        void _reset_compile_cache()
+        {
+            // Prepare list of loss functions, same size of model outputs.
         }

         public void compile(string optimizerName, ILossFunc lossName)
@@ -70,6 +87,20 @@ namespace Tensorflow.Keras.Engine
             int workers = 1,
             bool use_multiprocessing = false)
         {
+            var data_handler = new DataHandler(new DataHandlerArgs
+            {
+                X = x,
+                BatchSize = batch_size,
+                StepsPerEpoch = steps,
+                InitialEpoch = 0,
+                Epochs = 1,
+                MaxQueueSize = max_queue_size,
+                Workers = workers,
+                UseMultiprocessing = use_multiprocessing,
+                Model = this,
+                StepsPerExecution = _steps_per_execution
+            });
+
             throw new NotImplementedException("");
         }
     }

@@ -14,6 +14,7 @@
    limitations under the License.
 ******************************************************************************/

+using System.Linq;
 using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.Engine;
 using static Tensorflow.Binding;
@@ -44,6 +45,9 @@ namespace Tensorflow.Keras.Layers
             if (args.InputShape == null)
                 args.InputShape = args.InputLength;

+            if (args.BatchInputShape == null)
+                args.BatchInputShape = new int[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray();
+
             embeddings_initializer = embeddings_initializer ?? tf.random_uniform_initializer;
             SupportsMasking = mask_zero;
         }
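Not part of the diff: a worked example of the `BatchInputShape` default added above, under assumed values.

```csharp
// Sketch only: what the BatchInputShape default works out to for illustrative
// values (BatchSize = 32, InputShape = (10)). Plain LINQ, no TF types needed.
using System.Linq;

int batchSize = 32;
int[] inputShapeDims = { 10 };

// Mirrors: new int[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray()
int[] batchInputShape = new[] { batchSize }.Concat(inputShapeDims).ToArray();
// batchInputShape == [32, 10]: the batch dimension prepended to the input shape.
```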
@@ -34,10 +34,13 @@ namespace Tensorflow.Keras.Layers
         /// <summary>
         /// Turns positive integers (indexes) into dense vectors of fixed size.
+        /// This layer can only be used as the first layer in a model.
+        /// e.g. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
+        /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
         /// </summary>
-        /// <param name="input_dim"></param>
-        /// <param name="output_dim"></param>
-        /// <param name="embeddings_initializer"></param>
+        /// <param name="input_dim">Size of the vocabulary, i.e. maximum integer index + 1.</param>
+        /// <param name="output_dim">Dimension of the dense embedding.</param>
+        /// <param name="embeddings_initializer">Initializer for the embeddings matrix (see keras.initializers).</param>
         /// <param name="mask_zero"></param>
         /// <returns></returns>
         public Embedding Embedding(int input_dim,

@@ -36,9 +36,9 @@ namespace Tensorflow.Operations.Initializers
         public Tensor Apply(InitializerArgs args)
         {
-            if (args.DType == TF_DataType.DtInvalid)
-                args.DType = this.dtype;
-            return random_ops.truncated_normal(args.Shape, mean, stddev, dtype : dtype, seed: seed);
+            if (args.DType != TF_DataType.DtInvalid)
+                dtype = args.DType;
+            return random_ops.truncated_normal(args.Shape, mean, stddev, dtype: dtype, seed: seed);
         }
     }
 }
@@ -230,6 +230,35 @@ namespace Tensorflow
         public virtual T get_attr<T>(string name)
             => (T)get_attr(name);

+        public virtual T[] get_attr_list<T>(string name)
+        {
+            if (tf.executing_eagerly())
+                return (T[])get_attr(name);
+
+            AttrValue x = null;
+
+            lock (Locks.ProcessWide)
+            {
+                using var buf = new Buffer();
+                c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle);
+                tf.Status.Check(true);
+                x = AttrValue.Parser.ParseFrom(buf.DangerousMemoryBlock.Stream());
+            }
+
+            string oneof_value = x.ValueCase.ToString();
+            if (string.IsNullOrEmpty(oneof_value))
+                return null;
+
+            switch (typeof(T).Name)
+            {
+                case nameof(Int32):
+                    return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray();
+                default:
+                    return null;
+            }
+        }
+
         public virtual object get_attr(string name)
         {
             AttrValue x = null;

@@ -250,7 +279,7 @@ namespace Tensorflow
             if (oneof_value == "list")
                 throw new NotImplementedException($"Unsupported field type in {x.ToString()}");

-            if (oneof_value == "type")
+            if (string.Equals("type", oneof_value, StringComparison.OrdinalIgnoreCase))
                 return x.Type;

             object result = x.GetType().GetProperty(oneof_value).GetValue(x);
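Not part of the diff: a usage sketch of the new `get_attr_list<T>` helper, mirroring the `_Conv2DGrad` change earlier in this diff; `op` is a hypothetical graph-mode `Conv2D` operation.

```csharp
// Sketch only: reading list-valued attributes with the new helper.
// In eager mode get_attr_list simply falls back to get_attr and casts the boxed array.
int[] strides = op.get_attr_list<int>("strides");
int[] dilations = op.get_attr_list<int>("dilations");
string padding = op.get_attr<string>("padding");   // scalar attrs still use get_attr<T>
```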
@@ -85,26 +85,56 @@ namespace Tensorflow
                 allow_broadcast: false);

         public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
-            => tf_with(ops.name_scope(name, "zeros", shape), scope =>
-            {
-                dtype = dtype.as_base_dtype();
-                name = scope;
-                var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
-                Tensor zeros = null;
-                switch (dtype)
-                {
-                    case TF_DataType.TF_DOUBLE:
-                        zeros = constant(0d);
-                        break;
-                    case TF_DataType.TF_FLOAT:
-                        zeros = constant(0f);
-                        break;
-                    default:
-                        zeros = constant(0);
-                        break;
-                }
-                return fill(shape_tensor, zeros, name: name);
-            });
+        {
+            dtype = dtype.as_base_dtype();
+            if (tf.executing_eagerly())
+            {
+                return tf_with(ops.name_scope(name, "zeros", shape), scope =>
+                {
+                    name = scope;
+                    var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
+                    Tensor zeros = null;
+                    switch (dtype)
+                    {
+                        case TF_DataType.TF_DOUBLE:
+                            zeros = constant(0d);
+                            break;
+                        case TF_DataType.TF_FLOAT:
+                            zeros = constant(0f);
+                            break;
+                        default:
+                            zeros = constant(0);
+                            break;
+                    }
+                    return fill(shape_tensor, zeros, name: name);
+                });
+            }
+            else
+            {
+                return tf_with(ops.name_scope(name, "zeros", shape), scope =>
+                {
+                    name = scope;
+                    switch (dtype)
+                    {
+                        case TF_DataType.TF_BOOL:
+                            return _constant_if_small(false, shape, dtype, name);
+                        case TF_DataType.TF_DOUBLE:
+                            return _constant_if_small(0.0D, shape, dtype, name);
+                        case TF_DataType.TF_FLOAT:
+                            return _constant_if_small(0.0F, shape, dtype, name);
+                        case TF_DataType.TF_INT64:
+                            return _constant_if_small(0l, shape, dtype, name);
+                        case TF_DataType.TF_INT32:
+                            return _constant_if_small(0, shape, dtype, name);
+                        case TF_DataType.TF_INT8:
+                            return _constant_if_small<byte>(0, shape, dtype, name);
+                        default:
+                            throw new TypeError("can't find type for zeros");
+                    }
+                });
+            }
+        }

         public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0)
         {
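Not part of the diff: a caller-side sketch of the reworked `zeros`, with assumed shapes; the eager/graph split is internal, so call sites stay the same.

```csharp
// Sketch only: eager mode fills a shape tensor with a typed scalar constant,
// graph mode routes small shapes through _constant_if_small.
var z_f32 = array_ops.zeros(new TensorShape(2, 3));                            // 2x3 float32 zeros
var z_i64 = array_ops.zeros(new TensorShape(4), dtype: TF_DataType.TF_INT64);  // length-4 int64 zeros
```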
@@ -68,7 +68,7 @@ namespace Tensorflow
         /// <param name="seed"></param>
         /// <param name="name"></param>
         /// <returns></returns>
-        public static Tensor dropout_v2(Tensor x, float rate, Tensor noise_shape = null, int? seed = null, string name = null)
+        public static Tensor dropout_v2(Tensor x, Tensor rate, Tensor noise_shape = null, int? seed = null, string name = null)
         {
             return tf_with(ops.name_scope(name, "dropout", x), scope =>
             {

@@ -60,17 +60,17 @@ namespace Tensorflow.Train
             });
         }

-        public override Operation _apply_dense(Tensor grad, RefVariable var)
+        public override Operation _apply_dense(Tensor grad, ResourceVariable var)
         {
             var m = get_slot(var, "m");
             var v = get_slot(var, "v");
             var (beta1_power, beta2_power) = _get_beta_accumulators();
             return gen_training_ops.apply_adam(
-                var,
-                m,
-                v,
-                math_ops.cast(beta1_power, var.dtype.as_base_dtype()),
-                math_ops.cast(beta2_power, var.dtype.as_base_dtype()),
+                var.Handle,
+                m.Handle,
+                v.Handle,
+                math_ops.cast(beta1_power.Handle, var.dtype.as_base_dtype()),
+                math_ops.cast(beta2_power.Handle, var.dtype.as_base_dtype()),
                 math_ops.cast(_lr_t, var.dtype.as_base_dtype()),
                 math_ops.cast(_beta1_t, var.dtype.as_base_dtype()),
                 math_ops.cast(_beta2_t, var.dtype.as_base_dtype()),
@@ -278,8 +278,16 @@ namespace Tensorflow
         public virtual Operation _apply_dense(Tensor grad, ResourceVariable var)
         {
-            var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype());
-            return gen_training_ops.resource_apply_gradient_descent(var.Handle, alpha, grad, use_locking: _use_locking).op;
+            if (tf.executing_eagerly())
+            {
+                var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype());
+                return gen_training_ops.resource_apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op;
+            }
+            else
+            {
+                var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype());
+                return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op;
+            }
         }

         public virtual Operation _apply_dense(Tensor grad, RefVariable var)

@@ -314,6 +322,11 @@ namespace Tensorflow
             return _apply_sparse(gradient_no_duplicate_indices, var);
         }

+        public virtual Operation _apply_sparse(IndexedSlices grad, ResourceVariable var)
+        {
+            throw new NotImplementedException("_apply_sparse");
+        }
+
         public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var)
         {
             throw new NotImplementedException("_apply_sparse");

@@ -224,7 +224,7 @@ namespace Tensorflow
             var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length);
             idx += saveable.specs.Length;
             var restored = saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray());
-            assign_ops.Add(restored as ITensorOrOperation);
+            assign_ops.Add(restored);
         }

         return control_flow_ops.group(assign_ops.ToArray(), name: name);
@@ -13,6 +13,7 @@
    See the License for the specific language governing permissions and
    limitations under the License.
 ******************************************************************************/
+using static Tensorflow.Binding;

 namespace Tensorflow
 {

@@ -67,9 +67,7 @@ namespace Tensorflow
             {
                 ops.init_scope();
                 var variable = ops.internal_convert_to_tensor(op, as_ref: true);
-                if (variable.op.type == "Variable" ||
-                    variable.op.type == "VariableV2" ||
-                    variable.op.type == "AutoReloadVariable")
+                if (variable.dtype.is_ref_dtype())
                     yield return new ReferenceVariableSaveable(variable, "", name);
                 else
                     yield return new ResourceVariableSaveable(variable, "", name);

@@ -102,7 +100,7 @@ namespace Tensorflow
             if (convert_variable_to_tensor)
             {
-                if (var is ResourceVariable)
+                if (!var.dtype.is_ref_dtype())
                     tensor = var.GraphElement;
                 else
                     tensor = ops.internal_convert_to_tensor(var, as_ref: true);

@@ -41,7 +41,7 @@ namespace Tensorflow
             throw new NotImplementedException("");
         }

-        public static Tensor apply_adam(IVariableV1 var, IVariableV1 m, IVariableV1 v, Tensor beta1_power, Tensor beta2_power,
+        public static Tensor apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power,
             Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad,
             bool use_locking = false, bool use_nesterov = false, string name = null)
         {

@@ -64,7 +64,7 @@ namespace Tensorflow
             return _op.outputs[0];
         }

-        public static Tensor apply_gradient_descent(RefVariable var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
+        public static Tensor apply_gradient_descent(IVariableV1 var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
         {
             var _op = tf.OpDefLib._apply_op_helper("ApplyGradientDescent", name, new
             {
@@ -82,7 +82,7 @@ namespace Tensorflow
             if (tf.executing_eagerly())
             {
                 var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
-                    "ResourceApplyGradientDescent", name,
+                    "ResourceApplyGradientDescent", name,
                     null,
                     var, alpha, delta,
                     "use_locking", use_locking);

@@ -28,6 +28,8 @@ namespace Tensorflow
         protected Tensor _initial_value;
         public Tensor initial_value => _initial_value;

+        public Operation initializer => initializer_op;
+
         protected Tensor _parent_op;
         public Tensor parent_op => _parent_op;

@@ -73,6 +75,14 @@ namespace Tensorflow
         public ITensorOrOperation assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true)
         {
+            if(value.GetType() == typeof(Tensor))
+            {
+                var assign = gen_state_ops.assign(handle, value, use_locking: use_locking, name: name);
+                if (read_value)
+                    return assign;
+                return assign.op;
+            }
+
             var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
             var assign_op = gen_resource_variable_ops.assign_variable_op(
                 handle, value_tensor, name: name);

@@ -82,7 +92,7 @@ namespace Tensorflow
             return assign_op;
         }

-        public Tensor value() => _read_variable_op();
+        public Tensor value() => tf.executing_eagerly() ? _read_variable_op() : GraphElement;

         protected Tensor _read_variable_op()
         {

@@ -149,6 +159,7 @@ namespace Tensorflow
         {
         }

-        public Tensor AsTensor() => read_value();
+        public Tensor AsTensor()
+            => tf.executing_eagerly() ? read_value() : GraphElement;
     }
 }

@@ -33,10 +33,16 @@ namespace Tensorflow
     {
         public string UniqueId { get; }
         public string Name { get; }
+        /// <summary>
+        /// Handle is ref type
+        /// </summary>
         public Tensor Handle { get; }
         public string Device { get; }
         public Operation Initializer { get; }
         public Operation Op { get; }
+        /// <summary>
+        /// GraphElement is a copy of Handle
+        /// </summary>
         public Tensor GraphElement { get; }
         public Graph Graph { get; }
         public TF_DataType dtype { get; }
@@ -1,5 +1,6 @@
 using System;
 using Tensorflow.Eager;
+using static Tensorflow.Binding;

 namespace Tensorflow
 {

@@ -21,11 +22,6 @@ namespace Tensorflow
         public static implicit operator EagerTensor(ResourceVariable var)
             => var._dense_var_to_tensor() as EagerTensor;

-        public static implicit operator RefVariable(ResourceVariable var)
-        {
-            return null;
-        }
-
         public static implicit operator IntPtr(ResourceVariable var)
             => var._handle;

@@ -35,5 +31,13 @@ namespace Tensorflow
         {
             return value();
         }
+
+        public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
+        {
+            if (as_ref)
+                return handle;
+            else
+                return tf.executing_eagerly() ? AsTensor() : value();
+        }
     }
 }
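Not part of the diff: a sketch of the two paths through the `_TensorConversionFunction` added above, which `ops.convert_to_tensor` now routes `ResourceVariable` through (see the switch-case change near the end of this diff); `variable` is a hypothetical existing `ResourceVariable`.

```csharp
// Sketch only: dense value vs. ref-style conversion of a resource variable.
var dense  = variable._TensorConversionFunction();                // eager: AsTensor(), graph: value()
var as_ref = variable._TensorConversionFunction(as_ref: true);    // the raw resource handle tensor
```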
@@ -49,6 +49,7 @@ namespace Tensorflow
             VariableDef variable_def = null,
             TF_DataType dtype = TF_DataType.DtInvalid,
             string import_scope = "",
+            VariableAggregation aggregation = VariableAggregation.None,
             TensorShape shape = null)
         {
             if (variable_def != null)

@@ -65,6 +66,7 @@ namespace Tensorflow
                     caching_device: caching_device,
                     name: name,
                     dtype: dtype,
+                    aggregation: aggregation,
                     shape: shape);
             }
         }

@@ -75,6 +77,7 @@ namespace Tensorflow
             string caching_device = "",
             string name = null,
             TF_DataType dtype = TF_DataType.DtInvalid,
+            VariableAggregation aggregation = VariableAggregation.None,
             TensorShape shape = null)
         {
             var init_from_fn = initial_value.GetType().Name == "Func`1" ||
@@ -114,55 +117,43 @@ namespace Tensorflow
                 if (initial_value.GetType().GetInterface("IInitializer") != null)
                     initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype)));
                 else
-                    initial_value = ops.convert_to_tensor(init_from_fn ? (initial_value as Func<Tensor>)() : initial_value,
+                {
+                    var value = init_from_fn ? (initial_value as Func<Tensor>)() : initial_value;
+                    initial_value = ops.convert_to_tensor(value,
                         name: "initial_value",
                         dtype: dtype);
+                }
             });
             _shape = shape ?? (initial_value as Tensor).TensorShape;
             _initial_value = initial_value as Tensor;

-            handle = resource_variable_ops.eager_safe_variable_handle(
-                initial_value: _initial_value,
-                shape: _shape,
-                shared_name: shared_name,
-                name: name,
-                graph_mode: _in_graph_mode);
-            _dtype = _initial_value.dtype.as_base_dtype();
-
             if (_in_graph_mode)
             {
-                tf_with(ops.name_scope("IsInitialized"), delegate
-                {
-                    is_initialized_op = gen_resource_variable_ops.var_is_initialized_op(handle);
-                });
-                if(initial_value != null)
-                {
-                    tf_with(ops.name_scope("Assign"), scope1 =>
-                    {
-                        string n = scope1;
-                        var _initial_value2 = variables._try_guard_against_uninitialized_dependencies(name, _initial_value);
-                        initializer_op = gen_resource_variable_ops.assign_variable_op(handle, _initial_value2, name: n);
-                    });
-                }
+                handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
+                initializer_op = gen_state_ops.assign(handle, _initial_value, true).op;
-                // Manually assign reads to the handle's device to avoid log
-                // messages.
-                tf_with(ops.name_scope("Read"), delegate
-                {
-                    var value = gen_resource_variable_ops.read_variable_op(handle, _dtype);
-                    // _maybe_set_handle_data(dtype, handle, value);
-                    _graph_element = value;
-                });
+                ops.colocate_with(initializer_op);
+                _graph_element = gen_array_ops.identity(handle, name = "read");
                 ops.add_to_collections<IVariableV1>(collections, this);
+                _dtype = handle.dtype;
             }
             else
             {
+                handle = resource_variable_ops.eager_safe_variable_handle(
+                    initial_value: _initial_value,
+                    shape: _shape,
+                    shared_name: shared_name,
+                    name: name,
+                    graph_mode: _in_graph_mode);
                 gen_resource_variable_ops.assign_variable_op(handle, _initial_value);
                 is_initialized_op = null;
                 initializer_op = null;
                 _graph_element = null;
+                _dtype = _initial_value.dtype.as_base_dtype();
                 initial_value = _in_graph_mode ? initial_value : null;
             }
@@ -237,5 +228,23 @@ namespace Tensorflow
                 return array_ops.identity(value);
             });
         }
+
+        public VariableDef to_proto(string export_scope)
+        {
+            if (string.IsNullOrEmpty(export_scope) || Handle.name.StartsWith(export_scope))
+            {
+                var var_def = new VariableDef();
+                var_def.VariableName = ops.strip_name_scope(Handle.name, export_scope);
+                if (_initial_value != null)
+                    var_def.InitialValueName = ops.strip_name_scope(_initial_value.name, export_scope);
+                var_def.Trainable = _trainable;
+                var_def.InitializerName = ops.strip_name_scope(initializer.name, export_scope);
+                var_def.SnapshotName = ops.strip_name_scope(_graph_element.name, export_scope);
+                return var_def;
+            }
+
+            throw new NotImplementedException("to_proto RefVariable");
+        }
     }
 }

@@ -467,7 +467,7 @@ namespace Tensorflow
             case RefVariable varVal:
                 return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref);
             case ResourceVariable varVal:
-                return varVal.value();
+                return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref);
             case TensorShape ts:
                 return constant_op.constant(ts.dims, dtype: dtype, name: name);
             case int[] dims:
@@ -70,12 +70,14 @@ namespace Tensorflow
             bool use_resource = true,
             string name = null,
             TF_DataType dtype = TF_DataType.DtInvalid,
+            VariableAggregation aggregation = VariableAggregation.None,
             int[] shape = null)
             => new ResourceVariable(data,
                 trainable: trainable,
                 validate_shape: validate_shape,
                 name: name,
                 dtype: dtype,
+                aggregation: aggregation,
                 shape: shape);

         public Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
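Not part of the diff: a sketch of the new `aggregation` parameter on `tf.Variable`, matching how `Model._configure_steps_per_execution` uses it earlier in this PR.

```csharp
// Sketch only: creating a variable with the newly threaded-through aggregation option.
var steps_per_execution = tf.Variable(1,
    dtype: TF_DataType.TF_INT64,
    aggregation: VariableAggregation.OnlyFirstReplica);
```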