|
|
@@ -16,6 +16,7 @@ |
|
|
|
|
|
|
|
|
using NumSharp; |
|
|
using NumSharp; |
|
|
using System; |
|
|
using System; |
|
|
|
|
|
using System.Collections.Generic; |
|
|
using System.Linq; |
|
|
using System.Linq; |
|
|
using System.Numerics; |
|
|
using System.Numerics; |
|
|
using static Tensorflow.Binding; |
|
|
using static Tensorflow.Binding; |
|
|
@@ -25,7 +26,7 @@ namespace Tensorflow |
|
|
public partial class Tensor |
|
|
public partial class Tensor |
|
|
{ |
|
|
{ |
|
|
#if _REGEN |
|
|
#if _REGEN |
|
|
#region Compute |
|
|
|
|
|
|
|
|
#region Compute |
|
|
%operators = ["add", "sub", "mul", "div", "mod"] |
|
|
%operators = ["add", "sub", "mul", "div", "mod"] |
|
|
%operators_sign = ["+", "-", "*", "/", "%"] |
|
|
%operators_sign = ["+", "-", "*", "/", "%"] |
|
|
%operators_comparers = [">", "<", ">=", "<="] |
|
|
%operators_comparers = [">", "<", ">=", "<="] |
|
|
@@ -49,11 +50,11 @@ namespace Tensorflow |
|
|
% |
|
|
% |
|
|
% |
|
|
% |
|
|
public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); |
|
|
public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); |
|
|
#endregion |
|
|
|
|
|
|
|
|
#endregion |
|
|
#else |
|
|
#else |
|
|
#region Compute |
|
|
|
|
|
|
|
|
#region Compute |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); |
|
|
public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); |
|
|
public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs); |
|
|
public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs); |
|
|
public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); |
|
|
public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); |
|
|
@@ -129,31 +130,31 @@ namespace Tensorflow |
|
|
public static Tensor operator *(double lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); |
|
|
public static Tensor operator *(double lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); |
|
|
public static Tensor operator *(Tensor lhs, Complex rhs) => BinaryOpWrapper("mul", lhs, rhs); |
|
|
public static Tensor operator *(Tensor lhs, Complex rhs) => BinaryOpWrapper("mul", lhs, rhs); |
|
|
public static Tensor operator *(Complex lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); |
|
|
public static Tensor operator *(Complex lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); |
|
|
public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(sbyte lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, byte rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(byte lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, short rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(short lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, ushort rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(ushort lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, int rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(int lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, uint rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(uint lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, ulong rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(ulong lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, long rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(long lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, float rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(float lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, double rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(double lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, Complex rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Complex lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); |
|
|
|
|
|
|
|
|
public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(sbyte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, byte rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(byte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, short rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(short lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, ushort rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(ushort lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, int rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(int lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, uint rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(uint lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, ulong rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(ulong lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, long rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(long lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, float rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(float lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, double rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(double lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Tensor lhs, Complex rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
|
|
|
public static Tensor operator /(Complex lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); |
|
|
public static Tensor operator %(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); |
|
|
public static Tensor operator %(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); |
|
|
public static Tensor operator %(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mod", lhs, rhs); |
|
|
public static Tensor operator %(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mod", lhs, rhs); |
|
|
public static Tensor operator %(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); |
|
|
public static Tensor operator %(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); |
|
|
@@ -281,24 +282,43 @@ namespace Tensorflow |
|
|
public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, rhs); |
|
|
public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, rhs); |
|
|
public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); |
|
|
public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); |
|
|
public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); |
|
|
public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); |
|
|
#endregion |
|
|
|
|
|
|
|
|
#endregion |
|
|
#endif |
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private static readonly TF_DataType[] _intTfDataTypes = { |
|
|
private static readonly TF_DataType[] _intTfDataTypes = { |
|
|
TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64, |
|
|
TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64, |
|
|
TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, |
|
|
TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, |
|
|
TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 |
|
|
TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
private static string div_or_truediv<Tx, Ty>(string name, Tx x, Ty y) |
|
|
|
|
|
{ |
|
|
|
|
|
bool is_floating = false; |
|
|
|
|
|
var types = new List<bool>(); |
|
|
|
|
|
|
|
|
|
|
|
if (x is Tensor t1) |
|
|
|
|
|
types.add(t1.dtype.is_floating()); |
|
|
|
|
|
|
|
|
|
|
|
if (y is Tensor t2) |
|
|
|
|
|
types.add(t2.dtype.is_floating()); |
|
|
|
|
|
|
|
|
|
|
|
is_floating = types.Contains(true); |
|
|
|
|
|
|
|
|
|
|
|
return is_floating ? "truediv" : name; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) |
|
|
private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) |
|
|
{ |
|
|
{ |
|
|
TF_DataType dtype = TF_DataType.DtInvalid; |
|
|
TF_DataType dtype = TF_DataType.DtInvalid; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (x is Tensor tl) |
|
|
if (x is Tensor tl) |
|
|
dtype = tl.dtype.as_base_dtype(); |
|
|
dtype = tl.dtype.as_base_dtype(); |
|
|
if (y is Tensor tr) |
|
|
if (y is Tensor tr) |
|
|
dtype = tr.dtype.as_base_dtype(); |
|
|
dtype = tr.dtype.as_base_dtype(); |
|
|
|
|
|
|
|
|
|
|
|
if (name == "div") |
|
|
|
|
|
name = div_or_truediv(name, x, y); |
|
|
|
|
|
|
|
|
return tf_with(ops.name_scope(null, name, new { x, y }), scope => |
|
|
return tf_with(ops.name_scope(null, name, new { x, y }), scope => |
|
|
{ |
|
|
{ |
|
|
Tensor result; |
|
|
Tensor result; |
|
|
|