diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index b9308f18..ae14958f 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -17,110 +17,283 @@ using NumSharp; using System; using System.Linq; +using System.Numerics; using static Tensorflow.Binding; namespace Tensorflow { public partial class Tensor { - public static Tensor operator +(double x, Tensor y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(float x, Tensor y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(int x, Tensor y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(Tensor x, Tensor y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(Tensor x, int y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(Tensor x, float y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(Tensor x, double y) => BinaryOpWrapper("add", x, y); +#if _REGEN + #region Compute + %operators = ["add", "sub", "mul", "div", "mod"] + %operators_sign = ["+", "-", "*", "/", "%"] + %operators_comparers = [">", "<", ">=", "<="] + %operators_comparers_names = ["greater", "less", "greater_equal", "less_equal"] - public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1); + %possabilities = ["NDArray", "sbyte", "byte", "short", "ushort", "int", "uint", "ulong", "long", "float", "double", "Complex"] + + %foreach operators, operators_sign% + public static Tensor operator #2(Tensor lhs, Tensor rhs) => BinaryOpWrapper("#1", lhs, rhs); + %foreach possabilities% + public static Tensor operator #2(Tensor lhs, #101 rhs) => BinaryOpWrapper("#1", lhs, rhs); + public static Tensor operator #2(#101 lhs, Tensor rhs) => BinaryOpWrapper("#1", lhs, rhs); + % + % - public static Tensor operator -(double x, Tensor y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(int x, Tensor y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(Tensor x, float y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y); + %foreach operators_comparers_names, operators_comparers % + public static Tensor operator #2(Tensor lhs, Tensor rhs) => gen_math_ops.#1(lhs, rhs); + %foreach possabilities% + public static Tensor operator #2(Tensor lhs, #101 rhs) => gen_math_ops.#1(lhs, rhs); + public static Tensor operator #2(#101 lhs, Tensor rhs) => gen_math_ops.#1(lhs, rhs); + % + % + public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); + #endregion +#else + #region Compute - public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y); - public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y); - public static Tensor operator *(Tensor x, Tensor y) => BinaryOpWrapper("mul", x, y); - public static Tensor operator *(Tensor x, int y) => BinaryOpWrapper("mul", x, y); - public static Tensor operator *(Tensor tensor, bool constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, sbyte constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, byte constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, ushort constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, short constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, uint constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, long constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, ulong constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, float constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, double constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(bool constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(sbyte constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(byte constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(ushort constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(short constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(int constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(uint constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(long constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(ulong constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(NDArray x, Tensor y) => BinaryOpWrapper("mul", x, y); - public static Tensor operator *(Tensor x, NDArray y) => BinaryOpWrapper("mul", x, y); + + public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, sbyte rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(sbyte lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, byte rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(byte lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, short rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(short lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, ushort rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(ushort lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, int rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(int lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, uint rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(uint lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, ulong rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(ulong lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, long rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(long lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, float rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(float lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, double rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(double lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, Complex rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Complex lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator -(Tensor lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, NDArray rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(NDArray lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, sbyte rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(sbyte lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, byte rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(byte lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, short rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(short lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, ushort rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(ushort lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, int rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(int lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, uint rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(uint lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, ulong rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(ulong lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, long rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(long lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, float rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(float lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, double rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(double lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, Complex rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Complex lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator *(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, sbyte rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(sbyte lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, byte rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(byte lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, short rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(short lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, ushort rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(ushort lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, int rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(int lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, uint rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(uint lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, ulong rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(ulong lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, long rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(long lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, float rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(float lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, double rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(double lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, Complex rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Complex lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(sbyte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, byte rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(byte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, short rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(short lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, ushort rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(ushort lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, int rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(int lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, uint rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(uint lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, ulong rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(ulong lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, long rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(long lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, float rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(float lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, double rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(double lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, Complex rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Complex lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator %(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, sbyte rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(sbyte lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, byte rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(byte lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, short rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(short lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, ushort rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(ushort lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, int rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(int lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, uint rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(uint lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, ulong rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(ulong lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, long rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(long lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, float rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(float lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, double rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(double lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, Complex rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Complex lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator >(Tensor lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, NDArray rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(NDArray lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, sbyte rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(sbyte lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, byte rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(byte lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, short rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(short lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, ushort rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(ushort lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, int rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(int lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, uint rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(uint lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, ulong rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(ulong lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, long rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(long lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, float rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(float lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, double rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(double lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, Complex rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Complex lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator <(Tensor lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, NDArray rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(NDArray lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, sbyte rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(sbyte lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, byte rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(byte lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, short rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(short lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, ushort rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(ushort lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, int rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(int lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, uint rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(uint lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, ulong rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(ulong lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, long rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(long lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, float rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(float lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, double rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(double lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, Complex rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Complex lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator >=(Tensor lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, NDArray rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(NDArray lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, sbyte rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(sbyte lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, byte rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(byte lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, short rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(short lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, ushort rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(ushort lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, int rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(int lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, uint rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(uint lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, ulong rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(ulong lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, long rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(long lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, float rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(float lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, double rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(double lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, Complex rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Complex lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, NDArray rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(NDArray lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, sbyte rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(sbyte lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, byte rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(byte lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, short rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(short lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, ushort rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(ushort lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, int rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(int lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, uint rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(uint lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, ulong rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(ulong lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, long rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(long lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, float rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(float lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, double rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(double lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); + #endregion +#endif + private static readonly TF_DataType[] _intTfDataTypes = { TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64, TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 }; - public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); - public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); - public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); - public static Tensor operator /(Tensor x, Tensor y) => - _intTfDataTypes.Contains(x.dtype) - ? BinaryOpWrapper("floordiv", x, y) - : BinaryOpWrapper("truediv", x, y); - public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); - public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y); - public static Tensor operator /(Tensor x, double y) => BinaryOpWrapper("truediv", x, y); - - public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y); - - public static Tensor operator >(double x, Tensor y) => gen_math_ops.greater(x, y); - public static Tensor operator >(float x, Tensor y) => gen_math_ops.greater(x, y); - public static Tensor operator >(int x, Tensor y) => gen_math_ops.greater(x, y); - public static Tensor operator >(Tensor x, Tensor y) => gen_math_ops.greater(x, y); - public static Tensor operator >(Tensor x, int y) => gen_math_ops.greater(x, y); - public static Tensor operator >(Tensor x, float y) => gen_math_ops.greater(x, y); - public static Tensor operator >(Tensor x, double y) => gen_math_ops.greater(x, y); - - public static Tensor operator <(double x, Tensor y) => gen_math_ops.less(x, y); - public static Tensor operator <(float x, Tensor y) => gen_math_ops.less(x, y); - public static Tensor operator <(int x, Tensor y) => gen_math_ops.less(x, y); - public static Tensor operator <(Tensor x, Tensor y) => gen_math_ops.less(x, y); - public static Tensor operator <(Tensor x, int y) => gen_math_ops.less(x, y); - public static Tensor operator <(Tensor x, float y) => gen_math_ops.less(x, y); - public static Tensor operator <(Tensor x, double y) => gen_math_ops.less(x, y); - - public static Tensor operator >=(double x, Tensor y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(float x, Tensor y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(int x, Tensor y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(Tensor x, Tensor y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(Tensor x, int y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(Tensor x, float y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(Tensor x, double y) => gen_math_ops.greater_equal(x, y); - - public static Tensor operator <=(int x, Tensor y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(float x, Tensor y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(double x, Tensor y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(Tensor x, Tensor y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(Tensor x, int y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(Tensor x, float y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(Tensor x, double y) => gen_math_ops.less_equal(x, y); - private static Tensor BinaryOpWrapper(string name, Tx x, Ty y) { TF_DataType dtype = TF_DataType.DtInvalid; + if (x is Tensor tl) dtype = tl.dtype.as_base_dtype(); if (y is Tensor tr) @@ -128,15 +301,20 @@ namespace Tensorflow return tf_with(ops.name_scope(null, name, new { x, y }), scope => { - Tensor result = null; + Tensor result; var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); - switch (name.ToLower()) + switch (name.ToLowerInvariant()) { case "add": result = gen_math_ops.add(x1, y1, name: scope); break; + case "div": + result = _intTfDataTypes.Contains(x1.dtype) || _intTfDataTypes.Contains(y1.dtype) + ? gen_math_ops.floor_div(x1, y1, name: scope) + : gen_math_ops.real_div(x1, y1, name: scope); + break; case "floordiv": result = gen_math_ops.floor_div(x1, y1, name: scope); break; @@ -153,7 +331,7 @@ namespace Tensorflow result = gen_math_ops.floor_mod(x1, y1, name: scope); break; default: - throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty)}"); + throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}"); } return result;