Browse Source

pnnx handle index_put with empty indices and scalar values (#5288)

tags/20240410
nihui GitHub 2 years ago
parent
commit
7ed252c854
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 20 deletions
  1. +43
    -17
      tools/pnnx/src/ir.cpp
  2. +12
    -2
      tools/pnnx/src/pass_level3/fuse_expression.cpp
  3. +29
    -1
      tools/pnnx/src/pass_level5/eval_expression.cpp
  4. +3
    -0
      tools/pnnx/tests/test_Tensor_index_put.py

+ 43
- 17
tools/pnnx/src/ir.cpp View File

@@ -227,7 +227,7 @@ Parameter::Parameter(const torch::jit::Node* value_node)
{
at::Tensor t = value_node->t(torch::jit::attr::value);

if (t.dim() == 0)
if (t.dim() == 0 && t.numel() == 1)
{
if (t.scalar_type() == c10::ScalarType::Long)
{
@@ -1810,29 +1810,48 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
}
}

if (is_running_mean_var)
{
fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str());
}
else
{
fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str());
}

bool is_empty = false;
for (size_t i = 0; i < attr.shape.size(); i++)
{
fprintf(pyfp, "%d", attr.shape[i]);
if (i + 1 != attr.shape.size())
fprintf(pyfp, ",");
if (attr.shape[i] == 0)
is_empty = true;
}

if (attr.type == 1 || attr.type == 2 || attr.type == 3)
if (is_empty)
{
fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type));
fprintf(pyfp, " self.%s_%s = torch.from_numpy(np.empty((", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str());

for (size_t i = 0; i < attr.shape.size(); i++)
{
fprintf(pyfp, "%d,", attr.shape[i]);
}

fprintf(pyfp, "), dtype='%s'))\n", type_to_numpy_string(attr.type));
}
else
{
fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type));
if (is_running_mean_var)
{
fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str());
}
else
{
fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str());
}

for (size_t i = 0; i < attr.shape.size(); i++)
{
fprintf(pyfp, "%d,", attr.shape[i]);
}

if (attr.type == 1 || attr.type == 2 || attr.type == 3)
{
fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type));
}
else
{
fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type));
}
}
}

@@ -2320,7 +2339,14 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
}
if (param.type == 3)
{
fprintf(pyfp, "%f", param.f);
if (op->type == "Tensor.index_put" && it.first == "values")
{
fprintf(pyfp, "torch.tensor(%f)", param.f);
}
else
{
fprintf(pyfp, "%f", param.f);
}
}
if (param.type == 4)
{


+ 12
- 2
tools/pnnx/src/pass_level3/fuse_expression.cpp View File

@@ -62,7 +62,7 @@ static bool operand_maybe_tensor(const Operand* operand)
return operand_maybe_tensor(op->inputs[0]);
}

if (op->type == "aten::to" || op->type == "aten::detach")
if (op->type == "Tensor.to" || op->type == "aten::detach")
{
return operand_maybe_tensor(op->inputs[0]);
}
@@ -494,6 +494,15 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
{
fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
}
else if (!operand_maybe_tensor(operand))
{
std::string dtype = op->params.at("dtype").s;

// torch.xxx
expr += dtype + "(";
fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
expr += ")";
}
else
{
goto DEFAULT;
@@ -712,7 +721,8 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constan
{
// fuse noop type cast only
bool noop_to = (op->outputs[0]->type != -1) && (op->inputs[0]->type == op->outputs[0]->type);
need_fuse = noop_to;
bool is_scalar = !operand_maybe_tensor(op->outputs[0]);
need_fuse = noop_to || is_scalar;
}
if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
{


+ 29
- 1
tools/pnnx/src/pass_level5/eval_expression.cpp View File

@@ -181,7 +181,10 @@ static std::string eval_expression(const Operator* op)
|| t == "square"
|| t == "tan"
|| t == "tanh"
|| t == "trunc")
|| t == "trunc"
|| t == "torch.bool"
|| t == "torch.float"
|| t == "torch.long")
{
std::string a = exprstack.top();
exprstack.pop();
@@ -334,6 +337,31 @@ static std::string eval_expression(const Operator* op)
float r = trunc(af);
exprstack.push(std::to_string(r));
}
if (t == "torch.bool")
{
int r = int(af);
if (token_is_interger_literal(a))
{
r = std::stoi(a);
}

exprstack.push(r == 0 ? "False" : "True");
}
if (t == "torch.float")
{
float r = af;
exprstack.push(std::to_string(r));
}
if (t == "torch.long")
{
long r = long(af);
if (token_is_interger_literal(a))
{
r = std::stol(a);
}

exprstack.push(std::to_string(r));
}
}
else
{


+ 3
- 0
tools/pnnx/tests/test_Tensor_index_put.py View File

@@ -25,6 +25,9 @@ class Model(nn.Module):
z = z.clone()
x = x.index_put(indices=[torch.tensor([10,2])], values=y, accumulate=False)
z.index_put_(indices=[torch.tensor([1,0,0]), torch.tensor([3,2,1])], values=w, accumulate=True)

x[torch.tensor([1], dtype=torch.int64)] = torch.tensor(45).float()
x[torch.tensor([], dtype=torch.int64)] = torch.tensor(233).float()
return x, z

def test():


Loading…
Cancel
Save