From 42bffbcaf1864969b8c994ef2b17e7e11ec83f71 Mon Sep 17 00:00:00 2001 From: Precreator <1689487228@qq.com> Date: Thu, 26 Jun 2025 17:48:19 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E4=BA=86tensor.cc=E7=9A=84?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=BF=AE=E6=94=B9=EF=BC=8C=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E4=BA=86=E5=BC=A0=E9=87=8F=E7=9A=84=E8=BD=AC=E7=BD=AE=E5=92=8C?= =?UTF-8?q?=E6=B1=82=E4=B8=80=E4=B8=AA=E5=BC=A0=E9=87=8F=E5=9C=A8=E6=8C=87?= =?UTF-8?q?=E5=AE=9A=E7=BB=B4=E5=BA=A6=E4=B8=8A=E7=9A=84=E6=9C=80=E5=A4=A7?= =?UTF-8?q?=E5=80=BC=E7=9A=84=E7=B4=A2=E5=BC=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cc/tensor/tensor.cc | 55 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/cc/tensor/tensor.cc b/cc/tensor/tensor.cc index bedf5ff..43cbff4 100644 --- a/cc/tensor/tensor.cc +++ b/cc/tensor/tensor.cc @@ -16,11 +16,17 @@ std::shared_ptr Tensor::transpose() { std::vector transposed_data(size); // 请在这里写转置的代码 - + for(std::size_t i=0;i(new_shape); + return std::make_shared(new_shape,transposed_data); } @@ -50,6 +56,51 @@ std::shared_ptr argmax(const std::shared_ptr& tensor, int axis) auto result = std::make_shared(output_shape); // 这个问题似乎有点难,所以我们决定给你送点分。一个简单的办法是分axis为0还是为1来进行讨论,反正我们已经把问题简化为了,在一个二维的tensor里面,找到每一行或者每一列的最大值,并输出一个一维的tensor。 // 补全这里的代码。 + size_t rows=tensor->shape[0]; + size_t cols=tensor->shape[1]; + if(axis==0) + { + output_shape={1,cols}; + } + else if(axis==1) + { + output_shape={rows,1}; + } + if(axis==0)//qiu lie xiang liang + { + for(size_t j=0;jdata[i*cols+j]>maxx) + { + maxx=tensor->data[i*cols+j]; + maxx_id=i; + } + } + result->data[j] = static_cast(maxx_id); + } + } + else + { + for(size_t i=0;idata[i*cols+j]) + { + maxx=tensor->data[i*cols+j]; + maxx_id=j; + } + } + result->data[i] = static_cast(maxx_id); + } + } + return result; }