| @@ -14,10 +14,15 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Tensorflow.Gradients; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| public GradientActor GradientTape() | |||||
| => new GradientActor(); | |||||
| public Tensor[] gradients(Tensor[] ys, | public Tensor[] gradients(Tensor[] ys, | ||||
| Tensor[] xs, | Tensor[] xs, | ||||
| Tensor[] grad_ys = null, | Tensor[] grad_ys = null, | ||||
| @@ -0,0 +1,32 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using Tensorflow.Keras; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Optimizers; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class tensorflow | |||||
| { | |||||
| public KerasOptimizers optimizers => new KerasOptimizers(); | |||||
| public class KerasOptimizers | |||||
| { | |||||
| public SGD SGD(float learning_rate) => new SGD(learning_rate); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -41,6 +41,14 @@ namespace Tensorflow | |||||
| public static void append<T>(this IList<T> list, T element) | public static void append<T>(this IList<T> list, T element) | ||||
| => list.Add(element); | => list.Add(element); | ||||
| public static T[] concat<T>(this IList<T> list1, IList<T> list2) | |||||
| { | |||||
| var list = new List<T>(); | |||||
| list.AddRange(list1); | |||||
| list.AddRange(list2); | |||||
| return list.ToArray(); | |||||
| } | |||||
| public static void extend<T>(this List<T> list, IEnumerable<T> elements) | public static void extend<T>(this List<T> list, IEnumerable<T> elements) | ||||
| => list.AddRange(elements); | => list.AddRange(elements); | ||||
| @@ -0,0 +1,14 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public partial class EagerTensor | |||||
| { | |||||
| public static explicit operator TFE_TensorHandle(EagerTensor tensor) | |||||
| => tensor.tfe_tensor_handle; | |||||
| } | |||||
| } | |||||
| @@ -5,30 +5,39 @@ using System.Text; | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| public class EagerTensor : Tensor | |||||
| public partial class EagerTensor : Tensor | |||||
| { | { | ||||
| Status status = new Status(); | |||||
| TFE_TensorHandle tfe_tensor_handle; | |||||
| public EagerTensor(IntPtr handle) : base(handle) | public EagerTensor(IntPtr handle) : base(handle) | ||||
| { | { | ||||
| tfe_tensor_handle = handle; | |||||
| _handle = c_api.TFE_TensorHandleResolve(handle, status); | |||||
| } | } | ||||
| public EagerTensor(string value, string device_name) : base(value) | public EagerTensor(string value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
| } | } | ||||
| public EagerTensor(int value, string device_name) : base(value) | public EagerTensor(int value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
| } | } | ||||
| public EagerTensor(float[] value, string device_name) : base(value) | public EagerTensor(float[] value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
| } | } | ||||
| public EagerTensor(double[] value, string device_name) : base(value) | public EagerTensor(double[] value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
| } | } | ||||
| public EagerTensor(NDArray value, string device_name) : base(value) | public EagerTensor(NDArray value, string device_name) : base(value) | ||||
| { | { | ||||
| tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
| } | } | ||||
| public override string ToString() | public override string ToString() | ||||
| @@ -51,6 +60,8 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
| return $"b'{(string)nd}'"; | return $"b'{(string)nd}'"; | ||||
| case TF_DataType.TF_BOOL: | |||||
| return (nd.GetByte(0) > 0).ToString(); | |||||
| default: | default: | ||||
| return nd.ToString(); | return nd.ToString(); | ||||
| } | } | ||||
| @@ -32,31 +32,41 @@ namespace Tensorflow.Eager | |||||
| ctx.ensure_initialized(); | ctx.ensure_initialized(); | ||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| { | { | ||||
| var retVals = wrap_tfe_src.TFE_Py_Execute(ctx, ctx.device_name, op_name, inputs, attrs, 1, status); | |||||
| var retVals = wrap_tfe_src.TFE_Execute(ctx, ctx.device_name, op_name, inputs, attrs, 1, status); | |||||
| var t = c_api.TFE_TensorHandleResolve(retVals[0], status); | |||||
| status.Check(true); | |||||
| return new EagerTensor(t); | |||||
| return new EagerTensor(retVals[0]); | |||||
| } | } | ||||
| } | } | ||||
| public (TF_DataType, Tensor) args_to_matching_eager(Tensor[] l, Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid) | |||||
| public (TF_DataType, Tensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null) | |||||
| { | { | ||||
| var dtype = default_dtype; | |||||
| if(dtype == TF_DataType.DtInvalid) | |||||
| { | |||||
| var tensor = ops.convert_to_tensor(l, dtype, preferred_dtype: default_dtype, ctx: ctx); | |||||
| if (args.Length == 0 && default_dtype != TF_DataType.DtInvalid) | |||||
| return (default_dtype, null); | |||||
| if (dtype == TF_DataType.DtInvalid) | |||||
| dtype = tensor.dtype; | |||||
| if (args.Count(x => x is EagerTensor) == args.Length) | |||||
| return ((args[0] as EagerTensor).dtype, args.Select(x => x as EagerTensor).ToArray()); | |||||
| return (dtype, tensor); | |||||
| var dtype = TF_DataType.DtInvalid; | |||||
| foreach (var x in args) | |||||
| { | |||||
| if (x is EagerTensor et) | |||||
| dtype = et.dtype; | |||||
| } | } | ||||
| else | |||||
| if (dtype == TF_DataType.DtInvalid) | |||||
| { | { | ||||
| return (dtype, l[0]); | |||||
| var ret = new List<Tensor>(); | |||||
| foreach (var t in args) | |||||
| { | |||||
| ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx)); | |||||
| if (dtype == TF_DataType.DtInvalid) | |||||
| dtype = ret.Last().dtype; | |||||
| } | |||||
| return (dtype, ret.ToArray()); | |||||
| } | } | ||||
| else | |||||
| throw new NotImplementedException(""); | |||||
| } | } | ||||
| public void record_gradient(string op_name, InputList inputs, Dictionary<string, object> attrs, Tensor[] results, string name = null) | public void record_gradient(string op_name, InputList inputs, Dictionary<string, object> attrs, Tensor[] results, string name = null) | ||||
| @@ -101,6 +101,16 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_Op TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status); | public static extern TFE_Op TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status); | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="ctx">TFE_Context*</param> | |||||
| /// <param name="op_or_function_name">const char*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <param name="op_to_reset">TFE_Op*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_OpReset(IntPtr ctx, string op_or_function_name, IntPtr status, IntPtr op_to_reset); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| @@ -1,7 +1,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System; | using System; | ||||
| using static Tensorflow.OpDef.Types; | |||||
| using Tensorflow.Gradients; | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| @@ -10,16 +10,16 @@ namespace Tensorflow.Eager | |||||
| /// </summary> | /// </summary> | ||||
| public partial class wrap_tfe_src | public partial class wrap_tfe_src | ||||
| { | { | ||||
| public static IntPtr[] TFE_Py_Execute(Context ctx, | |||||
| public static IntPtr[] TFE_Execute(Context ctx, | |||||
| string device_name, | string device_name, | ||||
| string op_name, | string op_name, | ||||
| Tensor[] inputs, | Tensor[] inputs, | ||||
| object[] attrs, | object[] attrs, | ||||
| int num_outputs, | int num_outputs, | ||||
| Status status) | Status status) | ||||
| => TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs, num_outputs, status); | |||||
| => TFE_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs, num_outputs, status); | |||||
| public static IntPtr[] TFE_Py_ExecuteCancelable(Context ctx, | |||||
| public static IntPtr[] TFE_ExecuteCancelable(Context ctx, | |||||
| string device_name, | string device_name, | ||||
| string op_name, | string op_name, | ||||
| Tensor[] inputs, | Tensor[] inputs, | ||||
| @@ -27,14 +27,23 @@ namespace Tensorflow.Eager | |||||
| int num_outputs, | int num_outputs, | ||||
| Status status) | Status status) | ||||
| { | { | ||||
| var op = c_api.TFE_NewOp(ctx, op_name, status); | |||||
| var op = GetOp(ctx, op_name, status); | |||||
| status.Check(true); | status.Check(true); | ||||
| c_api.TFE_OpSetDevice(op, device_name, status); | c_api.TFE_OpSetDevice(op, device_name, status); | ||||
| if(status.ok()) | if(status.ok()) | ||||
| { | { | ||||
| for (int i = 0; i < inputs.Length; ++i) | for (int i = 0; i < inputs.Length; ++i) | ||||
| { | { | ||||
| var tensor_handle = c_api.TFE_NewTensorHandle(inputs[i], status); | |||||
| TFE_TensorHandle tensor_handle; | |||||
| switch (inputs[i]) | |||||
| { | |||||
| case EagerTensor et: | |||||
| tensor_handle = (TFE_TensorHandle)et; | |||||
| break; | |||||
| default: | |||||
| tensor_handle = c_api.TFE_NewTensorHandle(inputs[i], status); | |||||
| break; | |||||
| } | |||||
| c_api.TFE_OpAddInput(op, tensor_handle, status); | c_api.TFE_OpAddInput(op, tensor_handle, status); | ||||
| } | } | ||||
| } | } | ||||
| @@ -22,7 +22,7 @@ namespace Tensorflow.Eager | |||||
| var attr_list_sizes = new Dictionary<string, long>(); | var attr_list_sizes = new Dictionary<string, long>(); | ||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| { | { | ||||
| var op = c_api.TFE_NewOp(ctx, opName, status); | |||||
| var op = GetOp(ctx, opName, status); | |||||
| var op_def = Graph.TFE_GetOpDef(opName); | var op_def = Graph.TFE_GetOpDef(opName); | ||||
| @@ -101,11 +101,31 @@ namespace Tensorflow.Eager | |||||
| c_api.TFE_Execute(op, retVals, ref num_retvals, status); | c_api.TFE_Execute(op, retVals, ref num_retvals, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| var t = c_api.TFE_TensorHandleResolve(retVals[0], status); | |||||
| status.Check(true); | |||||
| return num_retvals == 0 ? null : new EagerTensor(retVals[0]); | |||||
| } | |||||
| } | |||||
| return new EagerTensor(t); | |||||
| private static TFE_Op GetOp(Context ctx, string op_or_function_name, Status status) | |||||
| { | |||||
| var maybe_op = ReleaseThreadLocalOp(); | |||||
| if (maybe_op != IntPtr.Zero) | |||||
| { | |||||
| c_api.TFE_OpReset(ctx, op_or_function_name, status, maybe_op); | |||||
| } | |||||
| else | |||||
| { | |||||
| maybe_op = c_api.TFE_NewOp(ctx, op_or_function_name, status); | |||||
| op = maybe_op; | |||||
| } | } | ||||
| status.Check(true); | |||||
| return maybe_op; | |||||
| } | |||||
| static TFE_Op op; | |||||
| private static TFE_Op ReleaseThreadLocalOp() | |||||
| { | |||||
| return op; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -126,19 +146,19 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| TFE_TensorHandle input_handle; | TFE_TensorHandle input_handle; | ||||
| // ConvertToTensor(); | |||||
| switch (inputs) | switch (inputs) | ||||
| { | { | ||||
| case Tensor input: | |||||
| input_handle = c_api.TFE_NewTensorHandle(input, status); | |||||
| case EagerTensor input: | |||||
| input_handle = (TFE_TensorHandle)input; | |||||
| break; | break; | ||||
| case Tensor[] input_list: | |||||
| input_handle = c_api.TFE_NewTensorHandle(input_list[0], status); | |||||
| case EagerTensor[] input_list: | |||||
| input_handle = (TFE_TensorHandle)input_list[0]; | |||||
| break; | break; | ||||
| default: | default: | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) | if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr)) | ||||
| { | { | ||||
| var dtype = c_api.TFE_TensorHandleDataType(input_handle); | var dtype = c_api.TFE_TensorHandleDataType(input_handle); | ||||
| @@ -0,0 +1,63 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Gradients | |||||
| { | |||||
| /// <summary> | |||||
| /// Record operations for automatic differentiation. | |||||
| /// | |||||
| /// Operations are recorded if they are executed within this context manager and | |||||
| /// at least one of their inputs is being "watched". | |||||
| /// | |||||
| /// Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`, | |||||
| /// where `trainable=True` is default in both cases) are automatically watched. | |||||
| /// Tensors can be manually watched by invoking the `watch` method on this context | |||||
| /// manager. | |||||
| /// </summary> | |||||
| public class GradientActor : IDisposable | |||||
| { | |||||
| bool _recording; | |||||
| bool _persistent; | |||||
| bool _watch_accessed_variables; | |||||
| bool _created_eagerly; | |||||
| Tape _tape; | |||||
| int tape_nesting_id_counter = 0; | |||||
| public GradientActor(bool persistent = false, | |||||
| bool watch_accessed_variables = true) | |||||
| { | |||||
| _persistent = persistent; | |||||
| _watch_accessed_variables = watch_accessed_variables; | |||||
| _created_eagerly = tf.context.executing_eagerly(); | |||||
| _push_tape(); | |||||
| } | |||||
| private void _push_tape() | |||||
| { | |||||
| if (_recording) | |||||
| throw new ValueError("Tape is still recording, This can happen if you try to " + | |||||
| "re-enter an already-active tape."); | |||||
| if (_tape == null) | |||||
| { | |||||
| _tape = new Tape(); | |||||
| _tape.tape = new GradientTape(_persistent, _watch_accessed_variables); | |||||
| _tape.nesting_id = tape_nesting_id_counter++; | |||||
| } | |||||
| _recording = true; | |||||
| } | |||||
| public void watch(Tensor x) | |||||
| { | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,31 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Gradients | |||||
| { | |||||
| /// <summary> | |||||
| /// Record operations for automatic differentiation. | |||||
| /// | |||||
| /// Operations are recorded if they are executed within this context manager and | |||||
| /// at least one of their inputs is being "watched". | |||||
| /// | |||||
| /// Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`, | |||||
| /// where `trainable=True` is default in both cases) are automatically watched. | |||||
| /// Tensors can be manually watched by invoking the `watch` method on this context | |||||
| /// manager. | |||||
| /// </summary> | |||||
| public class GradientTape | |||||
| { | |||||
| bool _persistent; | |||||
| bool _watch_accessed_variables; | |||||
| public GradientTape(bool persistent = false, | |||||
| bool watch_accessed_variables = true) | |||||
| { | |||||
| _persistent = persistent; | |||||
| _watch_accessed_variables = watch_accessed_variables; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,7 +1,14 @@ | |||||
| namespace Tensorflow.Eager | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Gradients | |||||
| { | { | ||||
| public class Tape | public class Tape | ||||
| { | { | ||||
| public GradientTape tape { get; set; } | |||||
| public int nesting_id { get; set; } | |||||
| public static bool IsDtypeTrainable(DataType dtype) | public static bool IsDtypeTrainable(DataType dtype) | ||||
| { | { | ||||
| switch (dtype) | switch (dtype) | ||||
| @@ -76,9 +76,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | ||||
| public partial class Graph : DisposableObject | public partial class Graph : DisposableObject | ||||
| #if !SERIALIZABLE | |||||
| , IEnumerable<Operation> | , IEnumerable<Operation> | ||||
| #endif | |||||
| { | { | ||||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | private Dictionary<int, ITensorOrOperation> _nodes_by_id; | ||||
| public Dictionary<string, ITensorOrOperation> _nodes_by_name; | public Dictionary<string, ITensorOrOperation> _nodes_by_name; | ||||
| @@ -541,7 +539,6 @@ namespace Tensorflow | |||||
| return debugString;*/ | return debugString;*/ | ||||
| } | } | ||||
| #if !SERIALIZABLE | |||||
| private IEnumerable<Operation> GetEnumerable() | private IEnumerable<Operation> GetEnumerable() | ||||
| => c_api_util.tf_operations(this); | => c_api_util.tf_operations(this); | ||||
| @@ -550,7 +547,6 @@ namespace Tensorflow | |||||
| IEnumerator IEnumerable.GetEnumerator() | IEnumerator IEnumerable.GetEnumerator() | ||||
| => throw new NotImplementedException(); | => throw new NotImplementedException(); | ||||
| #endif | |||||
| public static implicit operator IntPtr(Graph graph) | public static implicit operator IntPtr(Graph graph) | ||||
| { | { | ||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Optimizers | |||||
| { | |||||
| public class SGD | |||||
| { | |||||
| public SGD(float learning_rate) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -17,9 +17,6 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| #if SERIALIZABLE | |||||
| using Newtonsoft.Json; | |||||
| #endif | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -80,9 +77,6 @@ namespace Tensorflow | |||||
| /// reasons, or to ensure that the side effects of an op are observed | /// reasons, or to ensure that the side effects of an op are observed | ||||
| /// in the correct order. | /// in the correct order. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Operation[] control_inputs | public Operation[] control_inputs | ||||
| { | { | ||||
| get | get | ||||
| @@ -17,9 +17,6 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| #if SERIALIZABLE | |||||
| using Newtonsoft.Json; | |||||
| #endif | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -42,13 +39,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Tensor[] outputs => _outputs; | public Tensor[] outputs => _outputs; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Tensor output => _outputs.FirstOrDefault(); | public Tensor output => _outputs.FirstOrDefault(); | ||||
| public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | ||||
| @@ -60,9 +51,6 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// List this operation's output types. | /// List this operation's output types. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public TF_DataType[] _output_types | public TF_DataType[] _output_types | ||||
| { | { | ||||
| get | get | ||||
| @@ -15,15 +15,11 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Google.Protobuf.Collections; | using Google.Protobuf.Collections; | ||||
| #if SERIALIZABLE | |||||
| using Newtonsoft.Json; | |||||
| #endif | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -51,40 +47,23 @@ namespace Tensorflow | |||||
| private readonly Graph _graph; | private readonly Graph _graph; | ||||
| private NodeDef _node_def; | private NodeDef _node_def; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public string type => OpType; | public string type => OpType; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Graph graph => _graph; | public Graph graph => _graph; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int _id => _id_value; | public int _id => _id_value; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int _id_value { get; set; } | public int _id_value { get; set; } | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Operation op => this; | public Operation op => this; | ||||
| public TF_DataType dtype => TF_DataType.DtInvalid; | public TF_DataType dtype => TF_DataType.DtInvalid; | ||||
| public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | ||||
| public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| bool _is_stateful; | bool _is_stateful; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public NodeDef node_def | public NodeDef node_def | ||||
| { | { | ||||
| get | get | ||||
| @@ -57,7 +57,7 @@ namespace Tensorflow | |||||
| case TF_DataType.TF_FLOAT: | case TF_DataType.TF_FLOAT: | ||||
| return _constant_if_small(0.0F, shape, dtype, name); | return _constant_if_small(0.0F, shape, dtype, name); | ||||
| case TF_DataType.TF_INT64: | case TF_DataType.TF_INT64: | ||||
| return _constant_if_small(0l, shape, dtype, name); | |||||
| return _constant_if_small(0L, shape, dtype, name); | |||||
| case TF_DataType.TF_INT32: | case TF_DataType.TF_INT32: | ||||
| return _constant_if_small(0, shape, dtype, name); | return _constant_if_small(0, shape, dtype, name); | ||||
| case TF_DataType.TF_INT8: | case TF_DataType.TF_INT8: | ||||
| @@ -86,7 +86,7 @@ namespace Tensorflow | |||||
| var shape1 = concat(new[] | var shape1 = concat(new[] | ||||
| { | { | ||||
| shape(tensor_tensor)[$":{axis}"], | shape(tensor_tensor)[$":{axis}"], | ||||
| tf.expand_dims(leading_size, 0), | |||||
| leading_size, | |||||
| shape(tensor_tensor)[$"{axis + ndims_mask}:"] | shape(tensor_tensor)[$"{axis + ndims_mask}:"] | ||||
| }, 0); | }, 0); | ||||
| tensor_tensor = reshape(tensor, shape1); | tensor_tensor = reshape(tensor, shape1); | ||||
| @@ -136,16 +136,16 @@ namespace Tensorflow | |||||
| private static Tensor _constant_if_small<T>(T value, TensorShape shape, TF_DataType dtype, string name) | private static Tensor _constant_if_small<T>(T value, TensorShape shape, TF_DataType dtype, string name) | ||||
| { | { | ||||
| Tensor tShape = null; | |||||
| Tensor shape_t = null; | |||||
| if (shape.size < 1000) | if (shape.size < 1000) | ||||
| { | { | ||||
| return constant_op.constant(value, shape: shape, dtype: dtype, name: name); | return constant_op.constant(value, shape: shape, dtype: dtype, name: name); | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| tShape = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
| shape_t = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
| var c = constant_op.constant(0, dtype: dtype); | var c = constant_op.constant(0, dtype: dtype); | ||||
| return gen_array_ops.fill(tShape, c, name: name); | |||||
| return gen_array_ops.fill(shape_t, c, name: name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -313,15 +313,20 @@ namespace Tensorflow | |||||
| } | } | ||||
| public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | ||||
| { | |||||
| dtype = dtype.as_base_dtype(); | |||||
| return tf_with(ops.name_scope(name, "ones", new { dims }), scope => | |||||
| => tf_with(ops.name_scope(name, "ones", new { dims }), scope => | |||||
| { | { | ||||
| dtype = dtype.as_base_dtype(); | |||||
| name = scope; | name = scope; | ||||
| var output = _constant_if_small(1, dims, dtype, name); | |||||
| return output; | |||||
| switch (dtype) | |||||
| { | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| return _constant_if_small(1.0d, dims, dtype, name); | |||||
| case TF_DataType.TF_FLOAT: | |||||
| return _constant_if_small(1.0f, dims, dtype, name); | |||||
| default: | |||||
| return _constant_if_small(1, dims, dtype, name); | |||||
| } | |||||
| }); | }); | ||||
| } | |||||
| public static Tensor one_hot(Tensor indices, int depth, | public static Tensor one_hot(Tensor indices, int depth, | ||||
| Tensor on_value = null, | Tensor on_value = null, | ||||
| @@ -19,6 +19,8 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using System.Linq; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -50,9 +52,34 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor concat_v2<T, Ta>(T[] values, Ta axis, string name = null) | public static Tensor concat_v2<T, Ta>(T[] values, Ta axis, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| try | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "ConcatV2", name, null, | |||||
| values, axis); | |||||
| return _result; | |||||
| } | |||||
| catch (Exception) | |||||
| { | |||||
| return concat_v2_eager_fallback(values, axis, name, tf.context); | |||||
| } | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis }); | var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis }); | ||||
| return _op.output; | |||||
| } | |||||
| return _op.outputs[0]; | |||||
| private static Tensor concat_v2_eager_fallback<T1, T2>(T1[] values, T2 axis, string name, Context ctx) | |||||
| { | |||||
| var _attr_N = len(values); | |||||
| var (_attr_T, input) = _execute.args_to_matching_eager(ctx, args: values.Select(x => (object)x).ToArray()); | |||||
| var (_attr_Tidx, axis1) = _execute.args_to_matching_eager(ctx, default_dtype: tf.int32, args: new object[] { axis }); | |||||
| var _inputs_flat = input.concat(axis1); | |||||
| var _attrs = new object[] { "N", _attr_N, "T", _attr_T, "Tidx", _attr_Tidx }; | |||||
| return _execute.execute(ctx, "ConcatV2", _inputs_flat, _attrs, name: name); | |||||
| } | } | ||||
| public static Tensor[] concat_offset(Tensor concat_dim, Tensor[] shape, string name = null) | public static Tensor[] concat_offset(Tensor concat_dim, Tensor[] shape, string name = null) | ||||
| @@ -130,8 +157,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| var _op = _op_def_lib._apply_op_helper("Pack", name: name, args: new { values, axis }); | var _op = _op_def_lib._apply_op_helper("Pack", name: name, args: new { values, axis }); | ||||
| return _op.outputs[0]; | |||||
| return _op.output; | |||||
| } | } | ||||
| public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | ||||
| @@ -223,9 +249,16 @@ namespace Tensorflow | |||||
| /// <returns>A `Tensor`. Has the same type as `value`.</returns> | /// <returns>A `Tensor`. Has the same type as `value`.</returns> | ||||
| public static Tensor fill<T>(Tensor dims, T value, string name = null) | public static Tensor fill<T>(Tensor dims, T value, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Fill", name, new { dims, value }); | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Fill", name, null, | |||||
| dims, value); | |||||
| return _result; | |||||
| } | |||||
| return _op.outputs[0]; | |||||
| var _op = _op_def_lib._apply_op_helper("Fill", name, new { dims, value }); | |||||
| return _op.output; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -325,6 +358,14 @@ namespace Tensorflow | |||||
| public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) | public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Shape", name, null, | |||||
| input, "out_type", out_type); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("Shape", name, new { input, out_type }); | var _op = _op_def_lib._apply_op_helper("Shape", name, new { input, out_type }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| @@ -401,6 +442,16 @@ namespace Tensorflow | |||||
| int shrink_axis_mask = 0, | int shrink_axis_mask = 0, | ||||
| string name = null) | string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "StridedSlice", name, null, | |||||
| input, begin, end, strides, "begin_mask", begin_mask, | |||||
| "end_mask", end_mask, "ellipsis_mask", ellipsis_mask, | |||||
| "new_axis_mask", new_axis_mask, "shrink_axis_mask", shrink_axis_mask); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("StridedSlice", name, new | var _op = _op_def_lib._apply_op_helper("StridedSlice", name, new | ||||
| { | { | ||||
| input, | input, | ||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Linq; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -120,10 +121,12 @@ namespace Tensorflow | |||||
| { | { | ||||
| try | try | ||||
| { | { | ||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, "Mean", name, null, input, axis, "keep_dims", keep_dims); | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Mean", name, null, | |||||
| input, axis, "keep_dims", keep_dims); | |||||
| return _result; | return _result; | ||||
| } | } | ||||
| catch (Exception ex) | |||||
| catch (Exception) | |||||
| { | { | ||||
| return mean_eager_fallback(input as Tensor[], axis as Tensor, keep_dims: keep_dims, name: name, ctx: tf.context); | return mean_eager_fallback(input as Tensor[], axis as Tensor, keep_dims: keep_dims, name: name, ctx: tf.context); | ||||
| } | } | ||||
| @@ -136,21 +139,43 @@ namespace Tensorflow | |||||
| private static Tensor mean_eager_fallback(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null, Context ctx = null) | private static Tensor mean_eager_fallback(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null, Context ctx = null) | ||||
| { | { | ||||
| var (_attr_T, input) = _execute.args_to_matching_eager(inputs, ctx); | |||||
| var (_attr_Tidx, axis1) = _execute.args_to_matching_eager(new[] { axis }, ctx, TF_DataType.TF_INT32); | |||||
| var _inputs_flat = new Tensor[] { input, axis1 }; | |||||
| var (_attr_T, input) = _execute.args_to_matching_eager(ctx, args: new[] { inputs }); | |||||
| var (_attr_Tidx, axis1) = _execute.args_to_matching_eager(ctx, default_dtype: tf.int32, args: new[] { axis }); | |||||
| var _inputs_flat = input.concat(axis1); | |||||
| var _attrs = new object[] { "keep_dims", keep_dims, "T", _attr_T, "Tidx", _attr_Tidx }; | var _attrs = new object[] { "keep_dims", keep_dims, "T", _attr_T, "Tidx", _attr_Tidx }; | ||||
| var _result = _execute.execute(ctx, "Mean", _inputs_flat, _attrs, name: name); | |||||
| return _result; | |||||
| return _execute.execute(ctx, "Mean", _inputs_flat, _attrs, name: name); | |||||
| } | } | ||||
| public static Tensor prod<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null) | public static Tensor prod<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| try | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Prod", name, null, | |||||
| input, axis, "keep_dims", keep_dims); | |||||
| return _result; | |||||
| } | |||||
| catch (Exception) | |||||
| { | |||||
| return prod_eager_fallback(input as Tensor, axis as int[], keep_dims, name, tf.context); | |||||
| } | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("Prod", name, args: new { input, reduction_indices = axis, keep_dims }); | var _op = _op_def_lib._apply_op_helper("Prod", name, args: new { input, reduction_indices = axis, keep_dims }); | ||||
| return _op.output; | |||||
| } | |||||
| return _op.outputs[0]; | |||||
| private static Tensor prod_eager_fallback(Tensor input_t, int[] axis, bool keep_dims, string name, Context ctx = null) | |||||
| { | |||||
| var (_attr_T, input) = _execute.args_to_matching_eager(ctx, args: new[] { input_t }); | |||||
| var (_attr_Tidx, axis1) = _execute.args_to_matching_eager(ctx, default_dtype: tf.int32, args: new[] { axis }); | |||||
| var _inputs_flat = input.concat(axis1); | |||||
| var _attrs = new object[] { "keep_dims", keep_dims, "T", _attr_T, "Tidx", _attr_Tidx }; | |||||
| return _execute.execute(ctx, "Prod", _inputs_flat, _attrs, name: name); | |||||
| } | } | ||||
| public static Tensor acos(Tensor x, string name = null) | public static Tensor acos(Tensor x, string name = null) | ||||
| @@ -171,7 +196,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| { | { | ||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Add", name, null, x, y); | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Add", name, null, | |||||
| x, y); | |||||
| return _result; | return _result; | ||||
| } | } | ||||
| @@ -183,6 +210,14 @@ namespace Tensorflow | |||||
| public static Tensor add_v2<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor add_v2<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| { | { | ||||
| // forward_compatible(2019, 6, 25): | // forward_compatible(2019, 6, 25): | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "AddV2", name, null, | |||||
| x, y); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("AddV2", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("AddV2", name, args: new { x, y }); | ||||
| return _op.output; | return _op.output; | ||||
| @@ -517,7 +552,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| { | { | ||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Cast", name, null, x, "DstT", DstT, "Truncate", Truncate); | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Cast", name, null, | |||||
| x, "DstT", DstT, "Truncate", Truncate); | |||||
| return _result; | return _result; | ||||
| } | } | ||||
| @@ -528,6 +565,14 @@ namespace Tensorflow | |||||
| public static Tensor neg(Tensor x, string name = null) | public static Tensor neg(Tensor x, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Neg", name, null, | |||||
| x); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("Neg", name, args: new { x }); | var _op = _op_def_lib._apply_op_helper("Neg", name, args: new { x }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| @@ -535,6 +580,14 @@ namespace Tensorflow | |||||
| public static Tensor sqrt(Tensor x, string name = null) | public static Tensor sqrt(Tensor x, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Sqrt", name, null, | |||||
| x); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("Sqrt", name, args: new { x }); | var _op = _op_def_lib._apply_op_helper("Sqrt", name, args: new { x }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| @@ -544,7 +597,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| { | { | ||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Sub", name, null, x, y); | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Sub", name, null, | |||||
| x, y); | |||||
| return _result; | return _result; | ||||
| } | } | ||||
| @@ -562,9 +617,16 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Equal", name, args: new { x, y }); | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Equal", name, null, | |||||
| x, y); | |||||
| return _result; | |||||
| } | |||||
| return _op.outputs[0]; | |||||
| var _op = _op_def_lib._apply_op_helper("Equal", name, args: new { x, y }); | |||||
| return _op.output; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -578,24 +640,40 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor not_equal<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor not_equal<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("NotEqual", name, args: new { x, y }); | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "NotEqual", name, null, | |||||
| x, y); | |||||
| return _result; | |||||
| } | |||||
| return _op.outputs[0]; | |||||
| var _op = _op_def_lib._apply_op_helper("NotEqual", name, args: new { x, y }); | |||||
| return _op.output; | |||||
| } | } | ||||
| public static Tensor atan2(Tensor y, Tensor x, string name = null) | public static Tensor atan2(Tensor y, Tensor x, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Atan2", name, args: new { y, x }); | |||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Atan2", name, null, | |||||
| y, x); | |||||
| return _result; | |||||
| } | |||||
| return _op.outputs[0]; | |||||
| var _op = _op_def_lib._apply_op_helper("Atan2", name, args: new { y, x }); | |||||
| return _op.output; | |||||
| } | } | ||||
| public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| { | { | ||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Mul", name, null, x, y); | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Mul", name, null, | |||||
| x, y); | |||||
| return _result; | return _result; | ||||
| } | } | ||||
| @@ -791,14 +869,12 @@ namespace Tensorflow | |||||
| private static Tensor _sum_eager_fallback(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null, Context ctx = null) | private static Tensor _sum_eager_fallback(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null, Context ctx = null) | ||||
| { | { | ||||
| var (_attr_T, input) = _execute.args_to_matching_eager(inputs, ctx); | |||||
| var (_attr_Tidx, axis1) = _execute.args_to_matching_eager(new[] { axis }, ctx, TF_DataType.TF_INT32); | |||||
| var _inputs_flat = new Tensor[] { input, axis1 }; | |||||
| var (_attr_T, input) = _execute.args_to_matching_eager(ctx, args: new[] { inputs }); | |||||
| var (_attr_Tidx, axis1) = _execute.args_to_matching_eager(ctx, tf.int32, new[] { axis }); | |||||
| var _inputs_flat = input.concat(axis1); | |||||
| var _attrs = new object[] { "keep_dims", keep_dims, "T", _attr_T, "Tidx", _attr_Tidx }; | var _attrs = new object[] { "keep_dims", keep_dims, "T", _attr_T, "Tidx", _attr_Tidx }; | ||||
| var _result = _execute.execute(ctx, "Sum", _inputs_flat, _attrs, name: name); | |||||
| return _result; | |||||
| return _execute.execute(ctx, "Sum", _inputs_flat, _attrs, name: name); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -25,6 +25,14 @@ namespace Tensorflow | |||||
| public static Operation assign_variable_op(Tensor resource, Tensor value, string name = null) | public static Operation assign_variable_op(Tensor resource, Tensor value, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "AssignVariableOp", name, null, | |||||
| resource, value); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("AssignVariableOp", name, new { resource, value }); | var _op = _op_def_lib._apply_op_helper("AssignVariableOp", name, new { resource, value }); | ||||
| return _op; | return _op; | ||||
| @@ -84,6 +92,14 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null) | public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "ReadVariableOp", name, null, | |||||
| resource, "dtype", dtype); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("ReadVariableOp", name, new | var _op = _op_def_lib._apply_op_helper("ReadVariableOp", name, new | ||||
| { | { | ||||
| resource, | resource, | ||||
| @@ -128,8 +128,9 @@ namespace Tensorflow | |||||
| // When in eager mode, explicitly ensure so here. When in graph mode, it's | // When in eager mode, explicitly ensure so here. When in graph mode, it's | ||||
| // ensured by always generating different variable names. | // ensured by always generating different variable names. | ||||
| var exists = gen_resource_variable_ops.var_is_initialized_op(handle); | var exists = gen_resource_variable_ops.var_is_initialized_op(handle); | ||||
| throw new NotImplementedException(""); | |||||
| } | } | ||||
| return handle; | |||||
| } | } | ||||
| private static void _set_handle_shapes_and_types(Tensor handle, HandleData full_handle_data, bool graph_mode) | private static void _set_handle_shapes_and_types(Tensor handle, HandleData full_handle_data, bool graph_mode) | ||||
| @@ -4,7 +4,7 @@ | |||||
| <TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <TargetTensorFlow>1.14.1</TargetTensorFlow> | |||||
| <TargetTensorFlow>2.01.0</TargetTensorFlow> | |||||
| <Version>0.20.0</Version> | <Version>0.20.0</Version> | ||||
| <LangVersion>8.0</LangVersion> | <LangVersion>8.0</LangVersion> | ||||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
| @@ -36,7 +36,7 @@ https://tensorflownet.readthedocs.io</Description> | |||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | ||||
| <DefineConstants>TRACE;DEBUG;SERIALIZABLE_</DefineConstants> | |||||
| <DefineConstants>TRACE;DEBUG</DefineConstants> | |||||
| <PlatformTarget>x64</PlatformTarget> | <PlatformTarget>x64</PlatformTarget> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| @@ -18,19 +18,9 @@ using NumSharp; | |||||
| using System; | using System; | ||||
| using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
| using System.Globalization; | using System.Globalization; | ||||
| using System.Linq; | |||||
| using System.Numerics; | |||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | using System.Text; | ||||
| using NumSharp.Backends; | |||||
| using NumSharp.Backends.Unmanaged; | |||||
| using NumSharp.Utilities; | using NumSharp.Utilities; | ||||
| using static Tensorflow.c_api; | |||||
| #if SERIALIZABLE | |||||
| using Newtonsoft.Json; | |||||
| #endif | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -22,12 +22,7 @@ using System.Numerics; | |||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using NumSharp.Backends; | |||||
| using NumSharp.Backends.Unmanaged; | |||||
| using static Tensorflow.c_api; | using static Tensorflow.c_api; | ||||
| #if SERIALIZABLE | |||||
| using Newtonsoft.Json; | |||||
| #endif | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -47,17 +42,11 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// True if this Tensor holds data allocated by C#. | /// True if this Tensor holds data allocated by C#. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public bool IsMemoryOwner => AllocationType >= AllocationType.Marshal; | public bool IsMemoryOwner => AllocationType >= AllocationType.Marshal; | ||||
| /// <summary> | /// <summary> | ||||
| /// The allocation method used to create this Tensor. | /// The allocation method used to create this Tensor. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public AllocationType AllocationType { get; protected set; } | public AllocationType AllocationType { get; protected set; } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -16,7 +16,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| public static implicit operator Operation(Tensor tensor) | public static implicit operator Operation(Tensor tensor) | ||||
| => tensor.op; | |||||
| => tensor?.op; | |||||
| public static implicit operator TF_Tensor(Tensor tensor) | public static implicit operator TF_Tensor(Tensor tensor) | ||||
| => new TF_Tensor(tensor._handle); | => new TF_Tensor(tensor._handle); | ||||
| @@ -150,30 +150,26 @@ namespace Tensorflow | |||||
| /// Tensor has rank 0. | /// Tensor has rank 0. | ||||
| /// </returns> | /// </returns> | ||||
| public NDArray numpy() | public NDArray numpy() | ||||
| => NDims == 0 ? GetScalar(dtype) : GetNDArray(dtype); | |||||
| protected unsafe NDArray GetNDArray(TF_DataType dtype) | |||||
| { | { | ||||
| if(NDims == 0) | |||||
| { | |||||
| return GetScalar(dtype); | |||||
| } | |||||
| else | |||||
| switch (dtype) | |||||
| { | { | ||||
| switch (dtype) | |||||
| { | |||||
| case TF_DataType.TF_STRING: | |||||
| return StringData(); | |||||
| case TF_DataType.TF_INT32: | |||||
| return ToArray<int>(); | |||||
| case TF_DataType.TF_FLOAT: | |||||
| return ToArray<float>(); | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| return ToArray<double>(); | |||||
| default: | |||||
| return BufferToArray(); | |||||
| } | |||||
| case TF_DataType.TF_STRING: | |||||
| return StringData(); | |||||
| case TF_DataType.TF_INT32: | |||||
| return ToArray<int>(); | |||||
| case TF_DataType.TF_FLOAT: | |||||
| return ToArray<float>(); | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| return ToArray<double>(); | |||||
| default: | |||||
| return BufferToArray(); | |||||
| } | } | ||||
| } | } | ||||
| private unsafe NDArray GetScalar(TF_DataType dtype) | |||||
| protected unsafe NDArray GetScalar(TF_DataType dtype) | |||||
| { | { | ||||
| switch(dtype) | switch(dtype) | ||||
| { | { | ||||
| @@ -24,9 +24,6 @@ using System.Runtime.InteropServices; | |||||
| using System.Text; | using System.Text; | ||||
| using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| #if SERIALIZABLE | |||||
| using Newtonsoft.Json; | |||||
| #endif | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -47,33 +44,22 @@ namespace Tensorflow | |||||
| private readonly int _value_index; | private readonly int _value_index; | ||||
| private TF_Output? _tf_output; | private TF_Output? _tf_output; | ||||
| private readonly TF_DataType _override_dtype; | private readonly TF_DataType _override_dtype; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int Id => _id; | public int Id => _id; | ||||
| /// <summary> | /// <summary> | ||||
| /// The Graph that contains this tensor. | /// The Graph that contains this tensor. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Graph graph => op?.graph; | public Graph graph => op?.graph; | ||||
| /// <summary> | /// <summary> | ||||
| /// The Operation that produces this tensor as an output. | /// The Operation that produces this tensor as an output. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Operation op => _op; | public Operation op => _op; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Tensor[] outputs => op.outputs; | public Tensor[] outputs => op.outputs; | ||||
| /// <summary> | /// <summary> | ||||
| /// The string name of this tensor. | |||||
| /// The string name of this tensor.<br/> | |||||
| /// Tensor.name is meaningless when eager execution is enabled. | |||||
| /// </summary> | /// </summary> | ||||
| public string name => $"{(op == null ? "<unnamed>" : $"{op.name}:{_value_index}")}"; | public string name => $"{(op == null ? "<unnamed>" : $"{op.name}:{_value_index}")}"; | ||||
| @@ -86,48 +72,28 @@ namespace Tensorflow | |||||
| /// The DType of elements in this tensor. | /// The DType of elements in this tensor. | ||||
| /// </summary> | /// </summary> | ||||
| public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); | public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | ||||
| public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int NDims => rank; | public int NDims => rank; | ||||
| /// <summary> | /// <summary> | ||||
| /// The name of the device on which this tensor will be produced, or null. | /// The name of the device on which this tensor will be produced, or null. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public string Device => op.Device; | public string Device => op.Device; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int[] dims => shape; | public int[] dims => shape; | ||||
| /// <summary> | /// <summary> | ||||
| /// Used for keep other pointer when do implicit operating | /// Used for keep other pointer when do implicit operating | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public object Tag { get; set; } | public object Tag { get; set; } | ||||
| /// <summary> | |||||
| /// Associated resource variable | |||||
| /// </summary> | |||||
| public ResourceVariable ResourceVar { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the shape of a tensor. | /// Returns the shape of a tensor. | ||||
| @@ -175,9 +141,6 @@ namespace Tensorflow | |||||
| return rank < 0 ? null : shape; | return rank < 0 ? null : shape; | ||||
| } | } | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); | public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -316,9 +279,6 @@ namespace Tensorflow | |||||
| } else | } else | ||||
| throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType})."); | throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType})."); | ||||
| } | } | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public bool IsDisposed => _disposed; | public bool IsDisposed => _disposed; | ||||
| // public int tensor_int_val { get; set; } | // public int tensor_int_val { get; set; } | ||||
| @@ -4,9 +4,6 @@ using System.Collections.Generic; | |||||
| using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| #if SERIALIZABLE | |||||
| using Newtonsoft.Json; | |||||
| #endif | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -38,9 +35,6 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the size this shape represents. | /// Returns the size this shape represents. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int size | public int size | ||||
| { | { | ||||
| get | get | ||||
| @@ -244,6 +238,19 @@ namespace Tensorflow | |||||
| return (int[]) dims.Clone(); | return (int[]) dims.Clone(); | ||||
| } | } | ||||
| public int num_elements() | |||||
| { | |||||
| if(is_fully_defined()) | |||||
| { | |||||
| var size = 1; | |||||
| foreach (var dim in dims) | |||||
| size *= dim; | |||||
| return size; | |||||
| } | |||||
| return -1; | |||||
| } | |||||
| public override string ToString() | public override string ToString() | ||||
| { | { | ||||
| return shape.ToString(); | return shape.ToString(); | ||||
| @@ -253,7 +260,7 @@ namespace Tensorflow | |||||
| public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); | public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); | ||||
| public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes | public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes | ||||
| public static implicit operator TensorShape(int[] dims) => dims == null ? new TensorShape(new int[0]) : new TensorShape(dims); | |||||
| public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); | |||||
| public static explicit operator int(TensorShape shape) => shape.size; | public static explicit operator int(TensorShape shape) => shape.size; | ||||
| public static implicit operator TensorShape(int dim) => new TensorShape(dim); | public static implicit operator TensorShape(int dim) => new TensorShape(dim); | ||||
| @@ -19,11 +19,14 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using System.Linq; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class constant_op | public class constant_op | ||||
| { | { | ||||
| public static Execute _execute = new Execute(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Creates a constant tensor. | /// Creates a constant tensor. | ||||
| /// | /// | ||||
| @@ -43,7 +46,7 @@ namespace Tensorflow | |||||
| public static Tensor _constant_impl(object value, | public static Tensor _constant_impl(object value, | ||||
| TF_DataType dtype, | TF_DataType dtype, | ||||
| int[] shape, | |||||
| TensorShape shape, | |||||
| string name, | string name, | ||||
| bool verify_shape, | bool verify_shape, | ||||
| bool allow_broadcast) | bool allow_broadcast) | ||||
| @@ -53,6 +56,23 @@ namespace Tensorflow | |||||
| var t = convert_to_eager_tensor(value, tf.context, dtype: dtype); | var t = convert_to_eager_tensor(value, tf.context, dtype: dtype); | ||||
| if (shape == null) | if (shape == null) | ||||
| return t; | return t; | ||||
| if (t.shape.SequenceEqual(shape.dims)) | |||||
| return t; | |||||
| if (verify_shape) | |||||
| throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}."); | |||||
| var num_t = t.TensorShape.num_elements(); | |||||
| if (num_t == shape.num_elements()) | |||||
| throw new NotImplementedException(""); | |||||
| if(num_t == 1) | |||||
| { | |||||
| if (t.dtype == dtypes.@bool) | |||||
| throw new NotImplementedException(""); | |||||
| else | |||||
| return _eager_fill(shape, t, tf.context); | |||||
| } | |||||
| } | } | ||||
| Graph g = ops.get_default_graph(); | Graph g = ops.get_default_graph(); | ||||
| @@ -81,24 +101,38 @@ namespace Tensorflow | |||||
| return op.outputs[0]; | return op.outputs[0]; | ||||
| } | } | ||||
| private static Tensor _eager_fill(int[] dims, Tensor value, Context ctx) | |||||
| { | |||||
| var attr_t = value.dtype.as_datatype_enum(); | |||||
| var dims_t = convert_to_eager_tensor(dims, ctx, dtypes.int32); | |||||
| var inputs_flat = new[] { dims_t, value }; | |||||
| var attrs = new object[] { "T", attr_t, "index_type", TF_DataType.TF_INT32 }; | |||||
| var result = _execute.execute(ctx, "Fill", inputs_flat, attrs); | |||||
| return result; | |||||
| } | |||||
| private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid) | private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid) | ||||
| { | { | ||||
| switch (value) | switch (value) | ||||
| { | { | ||||
| case NDArray nd: | |||||
| return new EagerTensor(nd, ctx.device_name); | |||||
| case string str: | |||||
| return new EagerTensor(str, ctx.device_name); | |||||
| case int int32: | |||||
| return new EagerTensor(int32, ctx.device_name); | |||||
| case float float32: | |||||
| return new EagerTensor(float32, ctx.device_name); | |||||
| case double double64: | |||||
| return new EagerTensor(double64, ctx.device_name); | |||||
| case float[] float32s: | |||||
| return new EagerTensor(float32s, ctx.device_name); | |||||
| case double[] double64s: | |||||
| return new EagerTensor(double64s, ctx.device_name); | |||||
| case NDArray val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case string val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case int val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case int[] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case int[,] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case float val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case double val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case float[] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| case double[] val: | |||||
| return new EagerTensor(val, ctx.device_name); | |||||
| default: | default: | ||||
| throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}"); | throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}"); | ||||
| } | } | ||||
| @@ -112,7 +146,10 @@ namespace Tensorflow | |||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <param name="as_ref"></param> | /// <param name="as_ref"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor _tensor_shape_tensor_conversion_function(TensorShape s, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) | |||||
| public static Tensor _tensor_shape_tensor_conversion_function(TensorShape s, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| string name = null, | |||||
| bool as_ref = false) | |||||
| { | { | ||||
| var s_list = s.dims; | var s_list = s.dims; | ||||
| var int64_value = 0; | var int64_value = 0; | ||||
| @@ -125,15 +162,12 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| if(int64_value > 0) | |||||
| { | |||||
| dtype = TF_DataType.TF_INT32; | |||||
| } | |||||
| dtype = int64_value > 0 ? TF_DataType.TF_INT64 : TF_DataType.TF_INT32; | |||||
| if (string.IsNullOrEmpty(name)) | if (string.IsNullOrEmpty(name)) | ||||
| name = "shape_as_tensor"; | name = "shape_as_tensor"; | ||||
| return constant_op.constant(s_list, name: name); | |||||
| return constant_op.constant(s_list, dtype: dtype, name: name); | |||||
| } | } | ||||
| public static bool is_constant(ITensorOrOperation tensor_or_op) | public static bool is_constant(ITensorOrOperation tensor_or_op) | ||||
| @@ -201,6 +201,7 @@ namespace Tensorflow | |||||
| TF_DataType.TF_STRING => "string", | TF_DataType.TF_STRING => "string", | ||||
| TF_DataType.TF_INT32 => "int32", | TF_DataType.TF_INT32 => "int32", | ||||
| TF_DataType.TF_FLOAT => "float32", | TF_DataType.TF_FLOAT => "float32", | ||||
| TF_DataType.TF_BOOL => "bool", | |||||
| _ => type.ToString() | _ => type.ToString() | ||||
| }; | }; | ||||
| @@ -0,0 +1,93 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class BaseResourceVariable : VariableV1 | |||||
| { | |||||
| protected string _handle_name; | |||||
| protected string handle_name => _handle_name; | |||||
| protected string _unique_id; | |||||
| public string unique_id => _unique_id; | |||||
| protected bool _in_graph_mode; | |||||
| protected bool _trainable; | |||||
| public bool trainable => _trainable; | |||||
| protected Tensor _initial_value; | |||||
| public Tensor initial_value => _initial_value; | |||||
| protected Tensor _parent_op; | |||||
| public Tensor parent_op => _parent_op; | |||||
| protected Tensor _handle; | |||||
| /// <summary> | |||||
| /// Variable handle | |||||
| /// </summary> | |||||
| public Tensor handle => _handle; | |||||
| protected TensorShape _shape; | |||||
| public TensorShape shape => _shape; | |||||
| public BaseResourceVariable() : base() | |||||
| { | |||||
| } | |||||
| public void __init__(bool trainable = true, | |||||
| Tensor handle = null, | |||||
| string name = null, | |||||
| string unique_id = null, | |||||
| string handle_name = null) | |||||
| { | |||||
| _trainable = trainable; | |||||
| _handle_name = handle_name + ":0"; | |||||
| _unique_id = unique_id; | |||||
| _handle = handle; | |||||
| _name = name; | |||||
| } | |||||
| public override BaseResourceVariable assign(object value, bool use_locking = false, string name = null, bool read_value = true) | |||||
| { | |||||
| var value_tensor = ops.convert_to_tensor(value, dtype: dtype); | |||||
| var assign_op = gen_resource_variable_ops.assign_variable_op( | |||||
| _handle, value_tensor, name: name); | |||||
| if (read_value) | |||||
| return _lazy_read(assign_op, value_tensor); | |||||
| return null; | |||||
| } | |||||
| public Tensor value() => _read_variable_op(); | |||||
| protected Tensor _read_variable_op() | |||||
| { | |||||
| var result = gen_resource_variable_ops.read_variable_op(_handle, _dtype); | |||||
| // _maybe_set_handle_data(_dtype, _handle, result); | |||||
| return result; | |||||
| } | |||||
| BaseResourceVariable _lazy_read(Operation op, Tensor value) | |||||
| { | |||||
| variable_accessed(this); | |||||
| return new _UnreadVariable(_handle, _dtype, _shape, _in_graph_mode, _unique_id); | |||||
| } | |||||
| /// <summary> | |||||
| /// Records that `variable` was accessed for the tape and FuncGraph. | |||||
| /// </summary> | |||||
| void variable_accessed(BaseResourceVariable variable) | |||||
| { | |||||
| if (variable.trainable) | |||||
| ; // tape.variable_accessed(variable) | |||||
| } | |||||
| public override string ToString() | |||||
| => $"tf.Variable '{name}' shape={shape} dtype={dtype.as_numpy_name()}, numpy={numpy()}"; | |||||
| public NDArray numpy() => _read_variable_op().numpy(); | |||||
| } | |||||
| } | |||||
| @@ -26,7 +26,6 @@ namespace Tensorflow | |||||
| { | { | ||||
| public bool _in_graph_mode = true; | public bool _in_graph_mode = true; | ||||
| public Tensor _initial_value; | public Tensor _initial_value; | ||||
| public string _graph_key; | |||||
| public bool _trainable; | public bool _trainable; | ||||
| public Tensor _snapshot; | public Tensor _snapshot; | ||||
| @@ -51,13 +50,7 @@ namespace Tensorflow | |||||
| string name = null, | string name = null, | ||||
| VariableDef variable_def = null, | VariableDef variable_def = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| string import_scope = "") : base(initial_value, | |||||
| trainable, | |||||
| collections, | |||||
| validate_shape, | |||||
| caching_device, | |||||
| name, | |||||
| dtype) | |||||
| string import_scope = "") : base() | |||||
| { | { | ||||
| _in_graph_mode = true; | _in_graph_mode = true; | ||||
| @@ -13,14 +13,10 @@ | |||||
| } | } | ||||
| public static implicit operator Tensor(ResourceVariable var) | public static implicit operator Tensor(ResourceVariable var) | ||||
| { | |||||
| return null; | |||||
| } | |||||
| => var.handle; | |||||
| public static implicit operator ResourceVariable(Tensor var) | public static implicit operator ResourceVariable(Tensor var) | ||||
| { | |||||
| return null; | |||||
| } | |||||
| => var.ResourceVar; | |||||
| public static implicit operator RefVariable(ResourceVariable var) | public static implicit operator RefVariable(ResourceVariable var) | ||||
| { | { | ||||
| @@ -0,0 +1,63 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class ResourceVariable | |||||
| { | |||||
| public static Tensor operator +(ResourceVariable x, int y) => op_helper("add", x, y); | |||||
| public static Tensor operator +(ResourceVariable x, float y) => op_helper("add", x, y); | |||||
| public static Tensor operator +(ResourceVariable x, double y) => op_helper("add", x, y); | |||||
| public static Tensor operator -(ResourceVariable x, int y) => op_helper("sub", x, y); | |||||
| public static Tensor operator -(ResourceVariable x, float y) => op_helper("sub", x, y); | |||||
| public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y); | |||||
| public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); | |||||
| public static Tensor operator <(ResourceVariable x, Tensor y) => gen_math_ops.less(x.value(), y); | |||||
| public static Tensor operator >(ResourceVariable x, Tensor y) => gen_math_ops.greater(x.value(), y); | |||||
| private static Tensor op_helper<T>(string default_name, ResourceVariable x, T y) | |||||
| => tf_with(ops.name_scope(null, default_name, new { x, y }), scope => | |||||
| { | |||||
| string name = scope; | |||||
| var xVal = x.value(); | |||||
| var yTensor = ops.convert_to_tensor(y, xVal.dtype.as_base_dtype(), "y"); | |||||
| Tensor result = null; | |||||
| switch (default_name) | |||||
| { | |||||
| case "add": | |||||
| result = x.dtype == TF_DataType.TF_STRING ? | |||||
| gen_math_ops.add(xVal, yTensor, name) : | |||||
| gen_math_ops.add_v2(xVal, yTensor, name); | |||||
| break; | |||||
| case "sub": | |||||
| result = gen_math_ops.sub(xVal, yTensor, name); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| x.assign(result); | |||||
| result.ResourceVar = x; | |||||
| return result; | |||||
| }); | |||||
| } | |||||
| } | |||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using NumSharp; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -24,25 +25,14 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Variable based on resource handles. | /// Variable based on resource handles. | ||||
| /// </summary> | /// </summary> | ||||
| public partial class ResourceVariable : VariableV1 | |||||
| public partial class ResourceVariable : BaseResourceVariable | |||||
| { | { | ||||
| bool _in_graph_mode; | |||||
| Tensor _handle; | |||||
| TensorShape _shape; | |||||
| public TensorShape shape => _shape; | |||||
| string _handle_name; | |||||
| string _unique_id; | |||||
| public override string name => _handle_name; | |||||
| Operation _initializer_op; | Operation _initializer_op; | ||||
| public override Operation initializer => _initializer_op; | public override Operation initializer => _initializer_op; | ||||
| Tensor _initial_value; | |||||
| bool _trainable; | |||||
| public bool tranable => _trainable; | |||||
| Tensor _cached_value; | Tensor _cached_value; | ||||
| Tensor _graph_element; | Tensor _graph_element; | ||||
| public override Tensor graph_element => _graph_element; | public override Tensor graph_element => _graph_element; | ||||
| TF_DataType _dtype; | |||||
| public TF_DataType dtype => _dtype; | |||||
| public override string name => _handle.name; | |||||
| public string Device => _handle.Device; | public string Device => _handle.Device; | ||||
| public Graph Graph => _handle.graph; | public Graph Graph => _handle.graph; | ||||
| public override Operation op => _handle.op; | public override Operation op => _handle.op; | ||||
| @@ -56,13 +46,7 @@ namespace Tensorflow | |||||
| VariableDef variable_def = null, | VariableDef variable_def = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| string import_scope = "", | string import_scope = "", | ||||
| TensorShape shape = null) : base(initial_value, | |||||
| trainable, | |||||
| collections, | |||||
| validate_shape, | |||||
| caching_device, | |||||
| name, | |||||
| dtype) | |||||
| TensorShape shape = null) : base() | |||||
| { | { | ||||
| if (variable_def != null) | if (variable_def != null) | ||||
| { | { | ||||
| @@ -80,6 +64,8 @@ namespace Tensorflow | |||||
| dtype: dtype, | dtype: dtype, | ||||
| shape: shape); | shape: shape); | ||||
| } | } | ||||
| _handle.ResourceVar = this; | |||||
| } | } | ||||
| private void _init_from_args(object initial_value = null, | private void _init_from_args(object initial_value = null, | ||||
| @@ -130,10 +116,8 @@ namespace Tensorflow | |||||
| shared_name: shared_name, | shared_name: shared_name, | ||||
| name: name, | name: name, | ||||
| graph_mode: _in_graph_mode); | graph_mode: _in_graph_mode); | ||||
| _unique_id = unique_id; | |||||
| _handle_name = handle_name + ":0"; | |||||
| _dtype = _initial_value.dtype.as_base_dtype(); | _dtype = _initial_value.dtype.as_base_dtype(); | ||||
| // _constraint = constraint; | |||||
| if (_in_graph_mode) | if (_in_graph_mode) | ||||
| { | { | ||||
| @@ -160,19 +144,22 @@ namespace Tensorflow | |||||
| var value = _read_variable_op(); | var value = _read_variable_op(); | ||||
| _graph_element = value; | _graph_element = value; | ||||
| }); | }); | ||||
| ops.add_to_collections(collections, this); | |||||
| } | |||||
| else | |||||
| { | |||||
| gen_resource_variable_ops.assign_variable_op(_handle, _initial_value); | |||||
| } | } | ||||
| ops.add_to_collections(collections, this); | |||||
| base.__init__(trainable: trainable, | |||||
| handle: _handle, | |||||
| name: name, | |||||
| unique_id: unique_id, | |||||
| handle_name: handle_name); | |||||
| }); | }); | ||||
| } | } | ||||
| private Tensor _read_variable_op() | |||||
| { | |||||
| var result = gen_resource_variable_ops.read_variable_op(_handle, _dtype); | |||||
| // _maybe_set_handle_data(_dtype, _handle, result); | |||||
| return result; | |||||
| } | |||||
| private void _init_from_proto(VariableDef variable_def, string import_scope = null) | private void _init_from_proto(VariableDef variable_def, string import_scope = null) | ||||
| { | { | ||||
| _in_graph_mode = true; | _in_graph_mode = true; | ||||
| @@ -184,8 +171,7 @@ namespace Tensorflow | |||||
| var prepend_name_scope = ops.prepend_name_scope(variable_def.VariableName, import_scope: import_scope); | var prepend_name_scope = ops.prepend_name_scope(variable_def.VariableName, import_scope: import_scope); | ||||
| _handle = g.as_graph_element(prepend_name_scope) as Tensor; | _handle = g.as_graph_element(prepend_name_scope) as Tensor; | ||||
| _shape = new TensorShape(_handle.op.get_attr("shape") as TensorShapeProto); | _shape = new TensorShape(_handle.op.get_attr("shape") as TensorShapeProto); | ||||
| _handle_name = _handle.name; | |||||
| _unique_id = _handle_name; | |||||
| prepend_name_scope = ops.prepend_name_scope(variable_def.InitializerName, import_scope: import_scope); | prepend_name_scope = ops.prepend_name_scope(variable_def.InitializerName, import_scope: import_scope); | ||||
| _initializer_op = g.as_graph_element(prepend_name_scope) as Operation; | _initializer_op = g.as_graph_element(prepend_name_scope) as Operation; | ||||
| if (!string.IsNullOrEmpty(variable_def.InitialValueName)) | if (!string.IsNullOrEmpty(variable_def.InitialValueName)) | ||||
| @@ -235,10 +221,5 @@ namespace Tensorflow | |||||
| return array_ops.identity(value); | return array_ops.identity(value); | ||||
| }); | }); | ||||
| } | } | ||||
| public override string ToString() | |||||
| { | |||||
| return $"tf.ResourceVariable '{name}' shape={shape} dtype={dtype}"; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -31,6 +31,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public abstract class VariableV1 | public abstract class VariableV1 | ||||
| { | { | ||||
| protected string _name; | |||||
| public virtual string name { get; } | public virtual string name { get; } | ||||
| public virtual Tensor graph_element { get; } | public virtual Tensor graph_element { get; } | ||||
| public virtual Operation op { get; } | public virtual Operation op { get; } | ||||
| @@ -41,13 +42,10 @@ namespace Tensorflow | |||||
| public Tensor _is_initialized_op { get; set; } | public Tensor _is_initialized_op { get; set; } | ||||
| public VariableV1(object initial_value = null, | |||||
| bool trainable = true, | |||||
| List<string> collections = null, | |||||
| bool validate_shape = true, | |||||
| string caching_device = "", | |||||
| string name = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| protected TF_DataType _dtype; | |||||
| public TF_DataType dtype => _dtype; | |||||
| public VariableV1() | |||||
| { | { | ||||
| } | } | ||||
| @@ -57,12 +55,13 @@ namespace Tensorflow | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| public virtual ITensorOrOperation assign(object value, bool use_locking = false, string name = null, bool read_value = true) | |||||
| public virtual BaseResourceVariable assign(object value, bool use_locking = false, string name = null, bool read_value = true) | |||||
| { | { | ||||
| var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | |||||
| throw new NotImplementedException(""); | |||||
| /*var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | |||||
| if (read_value) | if (read_value) | ||||
| return assign; | return assign; | ||||
| return assign.op; | |||||
| return assign.op;*/ | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,31 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Represents a future for a read of a variable. | |||||
| /// Pretends to be the tensor if anyone looks. | |||||
| /// </summary> | |||||
| public class _UnreadVariable : BaseResourceVariable | |||||
| { | |||||
| public override string name => _in_graph_mode ? _parent_op.name : "UnreadVariable"; | |||||
| public _UnreadVariable(Tensor handle, TF_DataType dtype, TensorShape shape, | |||||
| bool in_graph_mode, string unique_id) : base() | |||||
| { | |||||
| _dtype = dtype; | |||||
| _shape = shape; | |||||
| _handle = handle; | |||||
| _unique_id = unique_id; | |||||
| _in_graph_mode = in_graph_mode; | |||||
| if (handle is EagerTensor) | |||||
| _handle_name = ""; | |||||
| else | |||||
| _handle_name = handle.name; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -46,13 +46,13 @@ namespace Tensorflow | |||||
| public void __enter__() | public void __enter__() | ||||
| { | { | ||||
| _name = _name ?? _default_name; | |||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| { | { | ||||
| (scope_name, old_scope_name) = enter_eager_name_scope(tf.context, _name); | (scope_name, old_scope_name) = enter_eager_name_scope(tf.context, _name); | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| _name = _name ?? _default_name; | |||||
| Graph g = null; | Graph g = null; | ||||
| if (_values is List<Tensor> vList) | if (_values is List<Tensor> vList) | ||||
| @@ -21,8 +21,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class tensorflow : ITensorFlowObject | public partial class tensorflow : ITensorFlowObject | ||||
| { | { | ||||
| public TF_DataType @byte = TF_DataType.TF_UINT8; | |||||
| public TF_DataType @sbyte = TF_DataType.TF_INT8; | |||||
| public TF_DataType byte8 = TF_DataType.TF_UINT8; | |||||
| public TF_DataType int8 = TF_DataType.TF_INT8; | |||||
| public TF_DataType int16 = TF_DataType.TF_INT16; | public TF_DataType int16 = TF_DataType.TF_INT16; | ||||
| public TF_DataType int32 = TF_DataType.TF_INT32; | public TF_DataType int32 = TF_DataType.TF_INT32; | ||||
| public TF_DataType int64 = TF_DataType.TF_INT64; | public TF_DataType int64 = TF_DataType.TF_INT64; | ||||
| @@ -41,27 +41,21 @@ namespace Tensorflow | |||||
| _constructThreadingObjects(); | _constructThreadingObjects(); | ||||
| } | } | ||||
| public ResourceVariable Variable<T>(T data, | public ResourceVariable Variable<T>(T data, | ||||
| bool trainable = true, | bool trainable = true, | ||||
| bool validate_shape = true, | bool validate_shape = true, | ||||
| string name = null, | string name = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| int[] shape = null) | int[] shape = null) | ||||
| { | |||||
| return new ResourceVariable(data, | |||||
| => new ResourceVariable(data, | |||||
| trainable: trainable, | trainable: trainable, | ||||
| validate_shape: validate_shape, | validate_shape: validate_shape, | ||||
| name: name, | name: name, | ||||
| dtype: dtype, | dtype: dtype, | ||||
| shape: shape); | shape: shape); | ||||
| } | |||||
| public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | ||||
| { | |||||
| return gen_array_ops.placeholder(dtype, shape, name); | |||||
| } | |||||
| => gen_array_ops.placeholder(dtype, shape, name); | |||||
| public void enable_eager_execution() | public void enable_eager_execution() | ||||
| { | { | ||||
| @@ -72,9 +66,7 @@ namespace Tensorflow | |||||
| public string VERSION => c_api.StringPiece(c_api.TF_Version()); | public string VERSION => c_api.StringPiece(c_api.TF_Version()); | ||||
| public Session get_default_session() | public Session get_default_session() | ||||
| { | |||||
| return ops.get_default_session(); | |||||
| } | |||||
| => ops.get_default_session(); | |||||
| public Session Session() | public Session Session() | ||||
| { | { | ||||
| @@ -19,7 +19,7 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="BenchmarkDotNet" Version="0.12.0" /> | <PackageReference Include="BenchmarkDotNet" Version="0.12.0" /> | ||||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="1.15.1" /> | |||||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.1.0" /> | |||||
| <PackageReference Include="TensorFlow.NET" Version="0.15.1" /> | <PackageReference Include="TensorFlow.NET" Version="0.15.1" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -64,15 +64,9 @@ Download [Bazel 0.29.1](https://github.com/bazelbuild/bazel/releases/tag/0.29.1) | |||||
| `pip install C:/tmp/tensorflow_pkg/tensorflow-1.15.0-cp36-cp36m-win_amd64.whl` | `pip install C:/tmp/tensorflow_pkg/tensorflow-1.15.0-cp36-cp36m-win_amd64.whl` | ||||
| ### Export more APIs | |||||
| ### Build specific version for tf.net | |||||
| Add more api to `c_api.h` | |||||
| ```c++ | |||||
| TF_CAPI_EXPORT extern void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); | |||||
| TF_CAPI_EXPORT extern void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); | |||||
| TF_CAPI_EXPORT extern void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); | |||||
| ``` | |||||
| https://github.com/SciSharp/tensorflow | |||||
| For Linux version, these APIs symbols should also be put into `tensorflow/c/version_script.lds` to be exported. | For Linux version, these APIs symbols should also be put into `tensorflow/c/version_script.lds` to be exported. | ||||
| Please refer to commit `https://github.com/SciSharp/tensorflow/commit/58122da06be3e7707500ad889dfd5c760a3e0424` | Please refer to commit `https://github.com/SciSharp/tensorflow/commit/58122da06be3e7707500ad889dfd5c760a3e0424` | ||||
| @@ -1,54 +0,0 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.Basics | |||||
| { | |||||
| [TestClass] | |||||
| public sealed class AssignTests | |||||
| { | |||||
| [Ignore("Not implemented")] | |||||
| [TestMethod] | |||||
| public void ShouldAssignVariable() | |||||
| { | |||||
| var raw_data = new[] { 1.0, 2.0, 8.0, -1.0, 0.0, 5.5, 6.0, 16.0 }; | |||||
| var expected = new[] { false, true, false, false, true, false, true }; | |||||
| var spike = tf.Variable(false); | |||||
| using (var sess = new Session()) | |||||
| { | |||||
| spike.initializer.run(session: sess); | |||||
| foreach (var i in range(1, 2)) | |||||
| { | |||||
| if (raw_data[i] - raw_data[i - 1] > 5d) | |||||
| { | |||||
| var updater = tf.assign(spike, tf.constant(true)); | |||||
| updater.eval(sess); | |||||
| } else | |||||
| { | |||||
| tf.assign(spike, tf.constant(true)).eval(sess); | |||||
| } | |||||
| Assert.AreEqual((bool) spike.eval(), expected[i - 1]); | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void Bug397() | |||||
| { | |||||
| // fix bug https://github.com/SciSharp/TensorFlow.NET/issues/397 | |||||
| var W = tf.Variable(-1, name: "weight_" + 1, dtype: tf.float32); | |||||
| var init = tf.global_variables_initializer(); | |||||
| var reluEval = tf.nn.relu(W); | |||||
| var nonZero = tf.assign(W, reluEval); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| sess.run(init); | |||||
| float result = nonZero.eval(); | |||||
| Assert.IsTrue(result == 0f); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,24 +0,0 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.Basics | |||||
| { | |||||
| [TestClass] | |||||
| public sealed class NegativeTests | |||||
| { | |||||
| [TestMethod] | |||||
| public void ShouldReturnNegative() | |||||
| { | |||||
| var x = tf.constant(new[,] { { 1, 2 } }); | |||||
| var neg_x = tf.negative(x); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var result = sess.run(neg_x); | |||||
| Assert.AreEqual(result[0][0], -1); | |||||
| Assert.AreEqual(result[0][1], -2); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,73 @@ | |||||
| using FluentAssertions; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using NumSharp; | |||||
| using System.Linq; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.Basics | |||||
| { | |||||
| [TestClass] | |||||
| public class VariableTest | |||||
| { | |||||
| [TestMethod] | |||||
| public void NewVariable() | |||||
| { | |||||
| var x = tf.Variable(10, name: "x"); | |||||
| Assert.AreEqual("x:0", x.name); | |||||
| Assert.AreEqual(0, x.shape.ndim); | |||||
| Assert.AreEqual(10, (int)x.numpy()); | |||||
| } | |||||
| [TestMethod] | |||||
| public void StringVar() | |||||
| { | |||||
| var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.@string); | |||||
| var mammal2 = tf.Variable("Tiger"); | |||||
| } | |||||
| [TestMethod] | |||||
| public void VarSum() | |||||
| { | |||||
| var x = tf.constant(3, name: "x"); | |||||
| var y = tf.Variable(x + 1, name: "y"); | |||||
| Assert.AreEqual(4, (int)y.numpy()); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Assign1() | |||||
| { | |||||
| var variable = tf.Variable(31, name: "tree"); | |||||
| var unread = variable.assign(12); | |||||
| Assert.AreEqual(12, (int)unread.numpy()); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Assign2() | |||||
| { | |||||
| var v1 = tf.Variable(10.0f, name: "v1"); | |||||
| var v2 = v1.assign(v1 + 1.0f); | |||||
| Assert.AreEqual(v1.numpy(), v2.numpy()); | |||||
| Assert.AreEqual(11f, (float)v1.numpy()); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Accumulation() | |||||
| { | |||||
| var x = tf.Variable(10, name: "x"); | |||||
| for (int i = 0; i < 5; i++) | |||||
| x = x + 1; | |||||
| Assert.AreEqual(15, (int)x.numpy()); | |||||
| } | |||||
| [TestMethod] | |||||
| public void ShouldReturnNegative() | |||||
| { | |||||
| var x = tf.constant(new[,] { { 1, 2 } }); | |||||
| var neg_x = tf.negative(x); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 1, 2 }, neg_x.shape)); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1, -2 }, neg_x.numpy().ToArray<int>())); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -93,59 +93,42 @@ namespace TensorFlowNET.UnitTest | |||||
| public void ZerosConst() | public void ZerosConst() | ||||
| { | { | ||||
| // small size | // small size | ||||
| var tensor = tf.zeros(new Shape(3, 2), TF_DataType.TF_INT32, "small"); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var result = sess.run(tensor); | |||||
| var tensor = tf.zeros(new Shape(3, 2), tf.int32, "small"); | |||||
| Assert.AreEqual(result.shape[0], 3); | |||||
| Assert.AreEqual(result.shape[1], 2); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data<int>())); | |||||
| } | |||||
| Assert.AreEqual(tensor.shape[0], 3); | |||||
| Assert.AreEqual(tensor.shape[1], 2); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, tensor.numpy().ToArray<int>())); | |||||
| // big size | // big size | ||||
| tensor = tf.zeros(new Shape(200, 100), TF_DataType.TF_INT32, "big"); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var result = sess.run(tensor); | |||||
| tensor = tf.zeros(new Shape(200, 100), tf.int32, "big"); | |||||
| Assert.AreEqual(result.shape[0], 200); | |||||
| Assert.AreEqual(result.shape[1], 100); | |||||
| Assert.AreEqual(tensor.shape[0], 200); | |||||
| Assert.AreEqual(tensor.shape[1], 100); | |||||
| var data = result.Data<int>(); | |||||
| Assert.AreEqual(0, data[0]); | |||||
| Assert.AreEqual(0, data[500]); | |||||
| Assert.AreEqual(0, data[result.size - 1]); | |||||
| } | |||||
| var data = tensor.numpy().ToArray<int>(); | |||||
| Assert.AreEqual(0, data[0]); | |||||
| Assert.AreEqual(0, data[500]); | |||||
| Assert.AreEqual(0, data[data.Length - 1]); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| public void OnesConst() | public void OnesConst() | ||||
| { | { | ||||
| var ones = tf.ones(new Shape(3, 2), TF_DataType.TF_DOUBLE, "ones"); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var result = sess.run(ones); | |||||
| Assert.AreEqual(result.shape[0], 3); | |||||
| Assert.AreEqual(result.shape[1], 2); | |||||
| Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result.Data<int>())); | |||||
| } | |||||
| var ones = tf.ones(new Shape(3, 2), tf.float32, "ones"); | |||||
| Assert.AreEqual(ones.dtype, tf.float32); | |||||
| Assert.AreEqual(ones.shape[0], 3); | |||||
| Assert.AreEqual(ones.shape[1], 2); | |||||
| Assert.IsTrue(new float[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(ones.numpy().ToArray<float>())); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| public void OnesToHalves() | public void OnesToHalves() | ||||
| { | { | ||||
| var ones = tf.ones(new Shape(3, 2), TF_DataType.TF_DOUBLE, "ones"); | |||||
| var ones = tf.ones(new Shape(3, 2), tf.float64, "ones"); | |||||
| var halfes = ones * 0.5; | var halfes = ones * 0.5; | ||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var result = sess.run(halfes); | |||||
| Assert.AreEqual(result.shape[0], 3); | |||||
| Assert.AreEqual(result.shape[1], 2); | |||||
| Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result.Data<double>())); | |||||
| } | |||||
| Assert.AreEqual(halfes.shape[0], 3); | |||||
| Assert.AreEqual(halfes.shape[1], 2); | |||||
| Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(halfes.numpy().ToArray<double>())); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -158,15 +141,10 @@ namespace TensorFlowNET.UnitTest | |||||
| }); | }); | ||||
| var tensor = tf.constant(nd); | var tensor = tf.constant(nd); | ||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var result = sess.run(tensor); | |||||
| var data = result.Data<int>(); | |||||
| var data = tensor.numpy().ToArray<int>(); | |||||
| Assert.AreEqual(result.shape[0], 2); | |||||
| Assert.AreEqual(result.shape[1], 3); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data)); | |||||
| } | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3 }, tensor.shape)); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data)); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -176,11 +154,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var b = tf.constant(2.0); | var b = tf.constant(2.0); | ||||
| var c = a * b; | var c = a * b; | ||||
| var sess = tf.Session(); | |||||
| double result = sess.run(c); | |||||
| sess.close(); | |||||
| Assert.AreEqual(6.0, result); | |||||
| Assert.AreEqual(6.0, (double)c); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -1,32 +0,0 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | |||||
| [TestClass] | |||||
| public class ConsumersTest : CApiTest | |||||
| { | |||||
| [TestMethod] | |||||
| public void Constant() | |||||
| { | |||||
| var X = tf.placeholder(tf.float64); | |||||
| var W = tf.constant(1.0D); | |||||
| var mul = tf.multiply(X, W); | |||||
| EXPECT_EQ(1, X.op.OutputNumConsumers(0)); | |||||
| EXPECT_EQ(1, W.op.OutputNumConsumers(0)); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Variable() | |||||
| { | |||||
| var X = tf.placeholder(tf.float64); | |||||
| var W = tf.Variable(1.0D, name: "var"); | |||||
| var mul = tf.multiply(X, W); | |||||
| EXPECT_EQ(1, X.op.OutputNumConsumers(0)); | |||||
| //EXPECT_EQ(1, W.op.OutputNumConsumers(0)); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -4,46 +4,51 @@ using System.Linq; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest | |||||
| namespace TensorFlowNET.UnitTest.Gradient | |||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class GradientTest | |||||
| public class GradientTapeTest | |||||
| { | { | ||||
| [TestMethod] | |||||
| public void GradientTape() | |||||
| { | |||||
| var x = tf.ones((2, 2)); | |||||
| using (var t = tf.GradientTape()) | |||||
| { | |||||
| t.watch(x); | |||||
| } | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Gradients() | public void Gradients() | ||||
| { | { | ||||
| var graph = tf.Graph().as_default(); | |||||
| var a = tf.constant(0.0); | var a = tf.constant(0.0); | ||||
| var b = 2.0 * a; | var b = 2.0 * a; | ||||
| Assert.AreEqual(b.name, "mul:0"); | |||||
| Assert.AreEqual(b.op.inputs[0].name, "mul/x:0"); | |||||
| Assert.AreEqual(b.op.inputs[1].name, "Const:0"); | |||||
| //Assert.AreEqual(b.name, "mul:0"); | |||||
| //Assert.AreEqual(b.op.inputs[0].name, "mul/x:0"); | |||||
| //Assert.AreEqual(b.op.inputs[1].name, "Const:0"); | |||||
| var ys = a + b; | var ys = a + b; | ||||
| Assert.AreEqual(ys.name, "add:0"); | |||||
| Assert.AreEqual(ys.op.inputs[0].name, "Const:0"); | |||||
| Assert.AreEqual(ys.op.inputs[1].name, "mul:0"); | |||||
| //Assert.AreEqual(ys.name, "add:0"); | |||||
| //Assert.AreEqual(ys.op.inputs[0].name, "Const:0"); | |||||
| //Assert.AreEqual(ys.op.inputs[1].name, "mul:0"); | |||||
| var g = tf.gradients(ys, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b }); | |||||
| Assert.AreEqual(g[0].name, "gradients/Fill:0"); | |||||
| Assert.AreEqual(g[1].name, "gradients/Fill:0"); | |||||
| //var g = tf.gradients(ys, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b }); | |||||
| //Assert.AreEqual(g[0].name, "gradients/Fill:0"); | |||||
| //Assert.AreEqual(g[1].name, "gradients/Fill:0"); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| public void Gradient2x() | public void Gradient2x() | ||||
| { | { | ||||
| var graph = tf.Graph().as_default(); | |||||
| using (var sess = tf.Session(graph)) | |||||
| { | |||||
| var x = tf.constant(7.0f); | |||||
| var y = x * x * tf.constant(0.1f); | |||||
| var x = tf.constant(7.0f); | |||||
| var y = x * x * tf.constant(0.1f); | |||||
| var grad = tf.gradients(y, x); | |||||
| Assert.AreEqual(grad[0].name, "gradients/AddN:0"); | |||||
| //var grad = tf.gradients(y, x); | |||||
| //Assert.AreEqual(grad[0].name, "gradients/AddN:0"); | |||||
| float r = sess.run(grad[0]); | |||||
| Assert.AreEqual(r, 1.4f); | |||||
| } | |||||
| //float r = sess.run(grad[0]); | |||||
| //Assert.AreEqual(r, 1.4f); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -6,10 +6,10 @@ using NumSharp; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.gradients_test | |||||
| namespace TensorFlowNET.UnitTest.Gradient | |||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class GradientsTest : PythonTest | |||||
| public class GradientTest : PythonTest | |||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void BroadcastToGrad() | public void BroadcastToGrad() | ||||
| @@ -179,7 +179,7 @@ namespace TensorFlowNET.UnitTest | |||||
| public void Autocast_Case4() | public void Autocast_Case4() | ||||
| { | { | ||||
| var sess = tf.Session().as_default(); | var sess = tf.Session().as_default(); | ||||
| var input = tf.placeholder(tf.@byte, shape: new TensorShape(6)); | |||||
| var input = tf.placeholder(tf.byte8, shape: new TensorShape(6)); | |||||
| var op = tf.reshape(input, new int[] {2, 3}); | var op = tf.reshape(input, new int[] {2, 3}); | ||||
| sess.run(tf.global_variables_initializer()); | sess.run(tf.global_variables_initializer()); | ||||
| var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f)); | var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f)); | ||||
| @@ -267,11 +267,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var tensor = new[] { 0, 1, 2, 3 }; | var tensor = new[] { 0, 1, 2, 3 }; | ||||
| var mask = np.array(new[] { true, false, true, false }); | var mask = np.array(new[] { true, false, true, false }); | ||||
| var masked = tf.boolean_mask(tensor, mask); | var masked = tf.boolean_mask(tensor, mask); | ||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var result = sess.run(masked); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, result.ToArray<int>())); | |||||
| } | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,152 +0,0 @@ | |||||
| using FluentAssertions; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using NumSharp; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | |||||
| [TestClass] | |||||
| public class VariableTest | |||||
| { | |||||
| [TestMethod] | |||||
| public void Initializer() | |||||
| { | |||||
| var x = tf.Variable(10, name: "x"); | |||||
| using (var session = tf.Session()) | |||||
| { | |||||
| session.run(x.initializer); | |||||
| var result = session.run(x); | |||||
| Assert.AreEqual(10, (int)result); | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void StringVar() | |||||
| { | |||||
| var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.@string); | |||||
| var mammal2 = tf.Variable("Tiger"); | |||||
| } | |||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/variable_scope | |||||
| /// how to create a new variable | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void VarCreation() | |||||
| { | |||||
| tf.Graph().as_default(); | |||||
| tf_with(tf.variable_scope("foo"), delegate | |||||
| { | |||||
| tf_with(tf.variable_scope("bar"), delegate | |||||
| { | |||||
| var v = tf.get_variable("v", new TensorShape(1)); | |||||
| v.name.Should().Be("foo/bar/v:0"); | |||||
| }); | |||||
| }); | |||||
| } | |||||
| /// <summary> | |||||
| /// how to reenter a premade variable scope safely | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void ReenterVariableScope() | |||||
| { | |||||
| tf.Graph().as_default(); | |||||
| variable_scope vs = null; | |||||
| tf_with(tf.variable_scope("foo"), v => vs = v); | |||||
| // Re-enter the variable scope. | |||||
| tf_with(tf.variable_scope(vs, auxiliary_name_scope: false), v => | |||||
| { | |||||
| var vs1 = (VariableScope)v; | |||||
| // Restore the original name_scope. | |||||
| tf_with(tf.name_scope(vs1.original_name_scope), delegate | |||||
| { | |||||
| var v1 = tf.get_variable("v", new TensorShape(1)); | |||||
| Assert.AreEqual(v1.name, "foo/v:0"); | |||||
| var c1 = tf.constant(new int[] { 1 }, name: "c"); | |||||
| Assert.AreEqual(c1.name, "foo/c:0"); | |||||
| }); | |||||
| }); | |||||
| } | |||||
| [TestMethod] | |||||
| public void ScalarVar() | |||||
| { | |||||
| var x = tf.constant(3, name: "x"); | |||||
| var y = tf.Variable(x + 1, name: "y"); | |||||
| var model = tf.global_variables_initializer(); | |||||
| using (var session = tf.Session()) | |||||
| { | |||||
| session.run(model); | |||||
| int result = session.run(y); | |||||
| Assert.AreEqual(result, 4); | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void Assign1() | |||||
| { | |||||
| var graph = tf.Graph().as_default(); | |||||
| var variable = tf.Variable(31, name: "tree"); | |||||
| var init = tf.global_variables_initializer(); | |||||
| var sess = tf.Session(graph); | |||||
| sess.run(init); | |||||
| NDArray result = sess.run(variable); | |||||
| Assert.IsTrue((int)result == 31); | |||||
| var assign = variable.assign(12); | |||||
| result = sess.run(assign); | |||||
| Assert.IsTrue((int)result == 12); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Assign2() | |||||
| { | |||||
| var v1 = tf.Variable(10.0f, name: "v1"); //tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||||
| var inc_v1 = v1.assign((RefVariable)v1 + 1.0f); | |||||
| // Add an op to initialize the variables. | |||||
| var init_op = tf.global_variables_initializer(); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| sess.run(init_op); | |||||
| // o some work with the model. | |||||
| inc_v1.op.run(session: sess); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// https://databricks.com/tensorflow/variables | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void Add() | |||||
| { | |||||
| tf.Graph().as_default(); | |||||
| int result = 0; | |||||
| Tensor x = tf.Variable(10, name: "x"); | |||||
| var init_op = tf.global_variables_initializer(); | |||||
| using (var session = tf.Session()) | |||||
| { | |||||
| session.run(init_op); | |||||
| for(int i = 0; i < 5; i++) | |||||
| { | |||||
| x = x + 1; | |||||
| result = session.run(x); | |||||
| print(result); | |||||
| } | |||||
| } | |||||
| Assert.AreEqual(15, result); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -11,7 +11,7 @@ namespace TensorFlowNET.UnitTest | |||||
| public void GetVersion() | public void GetVersion() | ||||
| { | { | ||||
| var ver = tf.VERSION; | var ver = tf.VERSION; | ||||
| Assert.IsTrue(ver.StartsWith("1.15.")); | |||||
| Assert.IsTrue(ver.StartsWith("2.")); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||