| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using NumSharp; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -18,6 +19,22 @@ namespace Tensorflow | |||
| var tensor = tf.constant(3112.0f); | |||
| } | |||
| }; | |||
| public Action<int> Constant2x3 | |||
| => (iterate) => | |||
| { | |||
| var nd = np.array(new byte[,] | |||
| { | |||
| {1, 2, 3}, | |||
| {4, 5, 6} | |||
| }); | |||
| for (int i = 0; i < iterate; i++) | |||
| { | |||
| var tensor = tf.constant(nd); | |||
| var data = tensor.numpy(); | |||
| } | |||
| }; | |||
| public Action<int> Variable | |||
| => (iterate) => | |||
| { | |||
| @@ -15,6 +15,9 @@ namespace Tensorflow | |||
| int batchSize = 1000; | |||
| // explaination of constant | |||
| mm.Execute(10, 100 * batchSize, cases.Constant2x3); | |||
| // 1 million float tensor 68M. | |||
| mm.Execute(10, 100 * batchSize, cases.Constant); | |||
| @@ -43,7 +43,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public partial class c_api | |||
| { | |||
| public const string TensorFlowLibName = "tensorflow"; | |||
| public const string TensorFlowLibName = @"C:\Users\haipi\Documents\Projects\tensorflow\bazel-bin\tensorflow\tensorflow"; | |||
| public static string StringPiece(IntPtr handle) | |||
| { | |||
| @@ -70,8 +70,8 @@ namespace Tensorflow.Eager | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| { | |||
| base.DisposeUnmanagedResources(handle); | |||
| //print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}"); | |||
| c_api.TF_DeleteTensor(_handle); | |||
| } | |||
| } | |||
| } | |||
| @@ -311,6 +311,7 @@ namespace Tensorflow | |||
| while (queue.Count > 0) | |||
| { | |||
| var op = queue.Dequeue(); | |||
| if (reached_ops.Contains(op)) | |||
| { | |||
| between_ops.Add(op); | |||
| @@ -278,7 +278,11 @@ namespace Tensorflow | |||
| // after removing the trailing '/'. | |||
| name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); | |||
| var node_def = ops._NodeDef(op_type, name, attrs: attrs); | |||
| if (name == "rnn/while/basic_rnn_cell/MatMul" | |||
| || name == "rnn/while/basic_rnn_cell/MatMul/Enter") | |||
| { | |||
| } | |||
| var input_ops = inputs.Select(x => x.op).ToArray(); | |||
| var control_inputs = _control_dependencies_for_inputs(input_ops); | |||
| @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||
| _channels_first = args.DataFormat == "channels_first"; | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||
| { | |||
| if (_channels_first) | |||
| { | |||
| @@ -121,7 +121,7 @@ namespace Tensorflow.Keras.Engine | |||
| /// <param name="input"></param> | |||
| /// <param name="is_training"></param> | |||
| /// <returns></returns> | |||
| public Tensor Apply(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| public Tensor Apply(Tensor inputs, bool is_training = false) | |||
| { | |||
| Tensor outputs = null; | |||
| @@ -148,7 +148,7 @@ namespace Tensorflow.Keras.Engine | |||
| if (!built) | |||
| MaybeBuild(inputs); | |||
| outputs = call(inputs, is_training: is_training, state: state); | |||
| outputs = call(inputs, is_training: is_training); | |||
| outputs = _set_connectivity_metadata_(inputs, outputs); | |||
| _handle_activity_regularization(inputs, outputs); | |||
| @@ -161,6 +161,35 @@ namespace Tensorflow.Keras.Engine | |||
| return outputs; | |||
| } | |||
| public Tensor[] Apply(Tensor[] inputs, Tensor state, bool is_training = false) | |||
| { | |||
| Tensor[] outputs = null; | |||
| callContext = callContext ?? new ThreadLocal<CallContext>() | |||
| { | |||
| Value = new CallContext() | |||
| }; | |||
| var eager = tf.executing_eagerly(); | |||
| using var ctxManager = CallContext.enter(); | |||
| string nameScope = ""; | |||
| if (eager) | |||
| nameScope = name; | |||
| else | |||
| nameScope = _name_scope(); | |||
| tf_with(ops.name_scope(nameScope), scope => | |||
| { | |||
| if (!built) | |||
| MaybeBuild(inputs[0]); | |||
| outputs = call(inputs, is_training: is_training, state: state); | |||
| }); | |||
| return outputs; | |||
| } | |||
| private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs) | |||
| { | |||
| /*var returnOutputs = new List<Tensor>(); | |||
| @@ -200,7 +229,12 @@ namespace Tensorflow.Keras.Engine | |||
| return null; | |||
| } | |||
| protected virtual Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| protected virtual Tensor call(Tensor inputs, bool is_training = false) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| protected virtual Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| @@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||
| { | |||
| Tensor outputs = null; | |||
| @@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool training = false) | |||
| { | |||
| var outputs = _convolution_op.__call__(inputs, kernel); | |||
| if (use_bias) | |||
| @@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool training = false) | |||
| { | |||
| Tensor outputs = null; | |||
| var rank = inputs.rank; | |||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.args = args; | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||
| { | |||
| var output = tf_utils.smart_cond(is_training, | |||
| () => tf.nn.dropout(inputs, | |||
| @@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||
| { | |||
| var dtype = inputs.dtype; | |||
| if (dtype != tf.int32 && dtype != tf.int64) | |||
| @@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers | |||
| .ToArray(); | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||
| { | |||
| return base.call(inputs, is_training, state); | |||
| return base.call(inputs, is_training); | |||
| } | |||
| } | |||
| } | |||
| @@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers | |||
| input_spec = new InputSpec(ndim: 4); | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||
| { | |||
| int[] pool_shape; | |||
| int[] strides; | |||
| @@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.args = args; | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||
| { | |||
| scale = math_ops.cast(args.Scale, args.DType); | |||
| offset = math_ops.cast(args.Offset, args.DType); | |||
| @@ -61,9 +61,8 @@ namespace Tensorflow.Layers | |||
| return (results[0], results[1]); | |||
| } | |||
| public Tensor[] __call__(Tensor inputs, | |||
| public Tensor __call__(Tensor inputs, | |||
| Tensor training = null, | |||
| Tensor state = null, | |||
| VariableScope scope = null) | |||
| { | |||
| _set_scope(scope); | |||
| @@ -88,16 +87,54 @@ namespace Tensorflow.Layers | |||
| { | |||
| _current_scope = scope2; | |||
| // Actually call layer | |||
| outputs = base.Apply(inputs, | |||
| is_training: training == null ? false : false, | |||
| state: state); | |||
| outputs = base.Apply(inputs[0], | |||
| is_training: training == null ? false : false); | |||
| }); | |||
| // Update global default collections. | |||
| _add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); | |||
| return outputs; | |||
| } | |||
| public Tensor[] __call__(Tensor[] inputs, | |||
| Tensor state = null, | |||
| Tensor training = null, | |||
| VariableScope scope = null) | |||
| { | |||
| _set_scope(scope); | |||
| _graph = ops._get_graph_from_inputs(inputs, graph: _graph); | |||
| variable_scope scope_context_manager = null; | |||
| if (built) | |||
| { | |||
| scope_context_manager = tf.variable_scope(_scope, | |||
| reuse: true, | |||
| auxiliary_name_scope: false); | |||
| } | |||
| else | |||
| { | |||
| scope_context_manager = tf.variable_scope(_scope, | |||
| reuse: _reuse, | |||
| auxiliary_name_scope: false); | |||
| } | |||
| Tensor[] outputs = null; | |||
| tf_with(scope_context_manager, scope2 => | |||
| { | |||
| _current_scope = scope2; | |||
| // Actually call layer | |||
| outputs = base.Apply(inputs, | |||
| state, | |||
| is_training: training == null ? false : false); | |||
| }); | |||
| // Update global default collections. | |||
| _add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); | |||
| return new Tensor[] { outputs }; | |||
| return outputs; | |||
| } | |||
| protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list) | |||
| @@ -326,7 +326,7 @@ namespace Tensorflow.Operations | |||
| protected override void _AddOpInternal(Operation op) | |||
| { | |||
| if (op.name == "gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad") | |||
| if (op.name == "rnn/basic_rnn_cell/kernel/Initializer/random_uniform/shape") | |||
| { | |||
| } | |||
| @@ -61,7 +61,7 @@ namespace Tensorflow | |||
| built = true; | |||
| } | |||
| public Tensor[] __call__(Tensor inputs, LSTMStateTuple state) | |||
| public Tensor __call__(Tensor inputs, LSTMStateTuple state) | |||
| { | |||
| _state = state; | |||
| return base.__call__(inputs); | |||
| @@ -74,7 +74,7 @@ namespace Tensorflow | |||
| /// <param name="training"></param> | |||
| /// <param name="state"></param> | |||
| /// <returns></returns> | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||
| { | |||
| var one = constant_op.constant(1, dtype: dtypes.int32); | |||
| // Parameters of gates are concatenated into one multiply for efficiency. | |||
| @@ -67,14 +67,14 @@ namespace Tensorflow | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) | |||
| { | |||
| // Most basic RNN: output = new_state = act(W * input + U * state + B). | |||
| var concat = array_ops.concat(new[] { inputs, state }, 1); | |||
| var concat = array_ops.concat(new[] { inputs[0], state }, 1); | |||
| var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor()); | |||
| gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); | |||
| var output = _activation(gate_inputs, null); | |||
| return output; | |||
| return new[] { output, output }; | |||
| } | |||
| } | |||
| } | |||
| @@ -364,7 +364,7 @@ namespace Tensorflow.Operations | |||
| if (sequence_length != null) | |||
| throw new NotImplementedException("sequence_length != null"); | |||
| else | |||
| outputs = cell.__call__(input_t_t, state: state1); | |||
| outputs = cell.__call__(new[] { input_t_t }, state: state1); | |||
| var (output, new_state) = (outputs[0], outputs[1]); | |||
| // Keras cells always wrap state as list, even if it's a single tensor. | |||
| @@ -326,7 +326,7 @@ namespace Tensorflow | |||
| // the updated inputs are reloaded from the c_api | |||
| lock (Locks.ProcessWide) | |||
| { | |||
| // c_api.UpdateEdge(_graph, output, input, tf.Status.Handle); | |||
| c_api.UpdateEdge(_graph, output, input, tf.Status.Handle); | |||
| //var updated_inputs = inputs; | |||
| tf.Status.Check(); | |||
| } | |||
| @@ -5,7 +5,7 @@ | |||
| <AssemblyName>TensorFlow.NET</AssemblyName> | |||
| <RootNamespace>Tensorflow</RootNamespace> | |||
| <TargetTensorFlow>2.2.0</TargetTensorFlow> | |||
| <Version>0.20.0</Version> | |||
| <Version>0.20.1</Version> | |||
| <LangVersion>8.0</LangVersion> | |||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||
| <Company>SciSharp STACK</Company> | |||
| @@ -19,13 +19,13 @@ | |||
| <Description>Google's TensorFlow full binding in .NET Standard. | |||
| Building, training and infering deep learning models. | |||
| https://tensorflownet.readthedocs.io</Description> | |||
| <AssemblyVersion>0.20.0.0</AssemblyVersion> | |||
| <AssemblyVersion>0.20.1.0</AssemblyVersion> | |||
| <PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x. | |||
| * Eager Mode is added finally. | |||
| * tf.keras is partially working. | |||
| * tf.data is added.</PackageReleaseNotes> | |||
| <FileVersion>0.20.0.0</FileVersion> | |||
| <FileVersion>0.20.1.0</FileVersion> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
| <SignAssembly>true</SignAssembly> | |||
| @@ -50,6 +50,8 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public AllocationType AllocationType { get; protected set; } | |||
| public IntPtr TensorDataPointer => TF_TensorData(_handle); | |||
| /// <summary> | |||
| /// Create a Tensor object from an existing TF handle | |||
| /// </summary> | |||
| @@ -261,7 +261,6 @@ namespace Tensorflow | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| { | |||
| c_api.TF_DeleteTensor(handle); | |||
| if (AllocationHandle == null) | |||
| return; | |||
| @@ -88,80 +88,83 @@ namespace Tensorflow | |||
| if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | |||
| collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | |||
| ops.init_scope(); | |||
| _in_graph_mode = !tf.Context.executing_eagerly(); | |||
| tf_with(ops.name_scope(name, "Variable"), scope => | |||
| tf_with(ops.init_scope2(), delegate | |||
| { | |||
| name = scope; | |||
| var handle_name = ops.name_from_scope_name(name); | |||
| string unique_id = ""; | |||
| string shared_name = ""; | |||
| if (_in_graph_mode) | |||
| { | |||
| shared_name = handle_name; | |||
| unique_id = shared_name; | |||
| } | |||
| else | |||
| var values = init_from_fn ? new object[0] : new object[] { initial_value }; | |||
| tf_with(ops.name_scope(name, "Variable", values), scope => | |||
| { | |||
| unique_id = $"{handle_name}_{ops.uid()}"; | |||
| shared_name = tf.Context.shared_name(); | |||
| } | |||
| var attr = new AttrValue(); | |||
| attr.List = new AttrValue.Types.ListValue(); | |||
| attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}")); | |||
| tf_with(ops.name_scope("Initializer"), delegate | |||
| { | |||
| if (initial_value.GetType().GetInterface("IInitializer") != null) | |||
| initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); | |||
| name = scope; | |||
| var handle_name = ops.name_from_scope_name(name); | |||
| string unique_id = ""; | |||
| string shared_name = ""; | |||
| if (_in_graph_mode) | |||
| { | |||
| shared_name = handle_name; | |||
| unique_id = shared_name; | |||
| } | |||
| else | |||
| { | |||
| var value = init_from_fn ? (initial_value as Func<Tensor>)() : initial_value; | |||
| initial_value = ops.convert_to_tensor(value, | |||
| name: "initial_value", | |||
| dtype: dtype); | |||
| unique_id = $"{handle_name}_{ops.uid()}"; | |||
| shared_name = tf.Context.shared_name(); | |||
| } | |||
| }); | |||
| _shape = shape ?? (initial_value as Tensor).TensorShape; | |||
| _initial_value = initial_value as Tensor; | |||
| var attr = new AttrValue(); | |||
| attr.List = new AttrValue.Types.ListValue(); | |||
| attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}")); | |||
| tf_with(ops.name_scope("Initializer"), delegate | |||
| { | |||
| if (initial_value.GetType().GetInterface("IInitializer") != null) | |||
| initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype))); | |||
| else | |||
| { | |||
| var value = init_from_fn ? (initial_value as Func<Tensor>)() : initial_value; | |||
| initial_value = ops.convert_to_tensor(value, | |||
| name: "initial_value", | |||
| dtype: dtype); | |||
| } | |||
| }); | |||
| _shape = shape ?? (initial_value as Tensor).TensorShape; | |||
| _initial_value = initial_value as Tensor; | |||
| if (_in_graph_mode) | |||
| { | |||
| handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | |||
| initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; | |||
| if (_in_graph_mode) | |||
| { | |||
| handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | |||
| initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; | |||
| ops.colocate_with(initializer_op); | |||
| ops.colocate_with(initializer_op); | |||
| _graph_element = gen_array_ops.identity(handle, name = "read"); | |||
| ops.add_to_collections<IVariableV1>(collections, this); | |||
| _dtype = handle.dtype; | |||
| } | |||
| else | |||
| { | |||
| handle = resource_variable_ops.eager_safe_variable_handle( | |||
| initial_value: _initial_value, | |||
| shape: _shape, | |||
| shared_name: shared_name, | |||
| name: name, | |||
| graph_mode: _in_graph_mode); | |||
| gen_resource_variable_ops.assign_variable_op(handle, _initial_value); | |||
| is_initialized_op = null; | |||
| initializer_op = null; | |||
| _graph_element = null; | |||
| _dtype = _initial_value.dtype.as_base_dtype(); | |||
| initial_value = _in_graph_mode ? initial_value : null; | |||
| } | |||
| _graph_element = gen_array_ops.identity(handle, name = "read"); | |||
| ops.add_to_collections<IVariableV1>(collections, this); | |||
| _dtype = handle.dtype; | |||
| } | |||
| else | |||
| { | |||
| handle = resource_variable_ops.eager_safe_variable_handle( | |||
| initial_value: _initial_value, | |||
| shape: _shape, | |||
| shared_name: shared_name, | |||
| name: name, | |||
| graph_mode: _in_graph_mode); | |||
| gen_resource_variable_ops.assign_variable_op(handle, _initial_value); | |||
| is_initialized_op = null; | |||
| initializer_op = null; | |||
| _graph_element = null; | |||
| _dtype = _initial_value.dtype.as_base_dtype(); | |||
| initial_value = _in_graph_mode ? initial_value : null; | |||
| } | |||
| base.__init__(trainable: trainable, | |||
| handle: handle, | |||
| name: name, | |||
| unique_id: unique_id, | |||
| handle_name: handle_name); | |||
| base.__init__(trainable: trainable, | |||
| handle: handle, | |||
| name: name, | |||
| unique_id: unique_id, | |||
| handle_name: handle_name); | |||
| }); | |||
| }); | |||
| } | |||
| @@ -30,6 +30,7 @@ | |||
| <ItemGroup> | |||
| <PackageReference Include="BenchmarkDotNet" Version="0.12.1" /> | |||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> | |||
| <PackageReference Include="TensorFlow.NET" Version="0.20.0" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -43,7 +43,7 @@ | |||
| <ItemGroup> | |||
| <PackageReference Include="FluentAssertions" Version="5.10.3" /> | |||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.0" /> | |||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.1" /> | |||
| <PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> | |||
| <PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> | |||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> | |||