diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index c19ecae7..2490d6a1 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -67,7 +67,16 @@ namespace Tensorflow if (NDim < 0 || other.NDim < 0) return new TensorShape(); else - return new TensorShape(NDim + other.NDim); + { + var concatenate_dims = new int[NDim + other.NDim]; + for (int i = 0; i < NDim; i++) + concatenate_dims[i] = dims[i]; + + for (int i = 0; i < other.NDim; i++) + concatenate_dims[NDim + i] = other.dims[i]; + + return new TensorShape(concatenate_dims); + } } public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);