| @@ -1,6 +1,7 @@ | |||||
| using MethodBoundaryAspect.Fody.Attributes; | using MethodBoundaryAspect.Fody.Attributes; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| @@ -8,6 +9,7 @@ using static Tensorflow.Binding; | |||||
| namespace Tensorflow.NumPy | namespace Tensorflow.NumPy | ||||
| { | { | ||||
| [DebuggerStepThrough] | |||||
| public sealed class AutoNumPyAttribute : OnMethodBoundaryAspect | public sealed class AutoNumPyAttribute : OnMethodBoundaryAspect | ||||
| { | { | ||||
| bool _changedMode = false; | bool _changedMode = false; | ||||
| @@ -10,7 +10,7 @@ namespace Tensorflow.NumPy | |||||
| { | { | ||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray argmax(NDArray a, Axis axis = null) | public static NDArray argmax(NDArray a, Axis axis = null) | ||||
| => new NDArray(math_ops.argmax(a, axis)); | |||||
| => new NDArray(math_ops.argmax(a, axis ?? 0)); | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray argsort(NDArray a, Axis axis = null) | public static NDArray argsort(NDArray a, Axis axis = null) | ||||
| @@ -31,7 +31,7 @@ namespace Tensorflow.NumPy | |||||
| public NDArray(IntPtr address, Shape shape, TF_DataType dtype) | public NDArray(IntPtr address, Shape shape, TF_DataType dtype) | ||||
| : base(address, shape, dtype) { NewEagerTensorHandle(); } | : base(address, shape, dtype) { NewEagerTensorHandle(); } | ||||
| public NDArray(Tensor tensor, bool eval = true) : base(tensor.Handle) | |||||
| public NDArray(Tensor tensor, bool clone = false) : base(tensor.Handle, clone: clone) | |||||
| { | { | ||||
| if (_handle is null) | if (_handle is null) | ||||
| { | { | ||||
| @@ -35,7 +35,7 @@ namespace Tensorflow | |||||
| tf.Status.Check(true); | tf.Status.Check(true); | ||||
| return num; | return num; | ||||
| } | } | ||||
| public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||||
| public int NumInputs => _handle == IntPtr.Zero ? -1 : c_api.TF_OperationNumInputs(_handle); | |||||
| private TF_DataType[] _input_types => _inputs_val._inputs.Select(x => x.dtype).ToArray(); | private TF_DataType[] _input_types => _inputs_val._inputs.Select(x => x.dtype).ToArray(); | ||||
| protected InputList _inputs_val; | protected InputList _inputs_val; | ||||
| @@ -23,7 +23,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class Operation | public partial class Operation | ||||
| { | { | ||||
| public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | |||||
| public int NumOutputs => _handle == IntPtr.Zero ? -1 : c_api.TF_OperationNumOutputs(_handle); | |||||
| public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index)); | public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index)); | ||||
| public int OutputListLength(string name) | public int OutputListLength(string name) | ||||
| @@ -38,7 +38,7 @@ namespace Tensorflow | |||||
| public virtual Tensor[] outputs => _outputs; | public virtual Tensor[] outputs => _outputs; | ||||
| public Tensor output => _outputs.FirstOrDefault(); | public Tensor output => _outputs.FirstOrDefault(); | ||||
| public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||||
| public int NumControlOutputs => _handle == IntPtr.Zero ? -1 : c_api.TF_OperationNumControlOutputs(_handle); | |||||
| public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | ||||
| @@ -39,9 +39,12 @@ namespace Tensorflow | |||||
| /// Create a Tensor object from an existing TF handle | /// Create a Tensor object from an existing TF handle | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="handle">Handle to a <see cref="Tensor"/> object.</param> | /// <param name="handle">Handle to a <see cref="Tensor"/> object.</param> | ||||
| public Tensor(SafeTensorHandle handle) | |||||
| public unsafe Tensor(SafeTensorHandle handle, bool clone = false) | |||||
| { | { | ||||
| _handle = handle; | _handle = handle; | ||||
| if (clone) | |||||
| _handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer()); | |||||
| isCreatedInGraphMode = !tf.executing_eagerly(); | isCreatedInGraphMode = !tf.executing_eagerly(); | ||||
| } | } | ||||
| @@ -55,7 +55,7 @@ namespace Tensorflow | |||||
| return new NDArray(str, shape); | return new NDArray(str, shape); | ||||
| } | } | ||||
| return new NDArray(this); | |||||
| return new NDArray(this, clone: true); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||