Browse Source

feat(imperative/python): add transformation manager

GitOrigin-RevId: a3c1732ffd
tags/v1.8.0
Megvii Engine Team 4 years ago
parent
commit
d3689c3f3c
1 changed files with 61 additions and 0 deletions
  1. +61
    -0
      imperative/python/src/transformation.h

+ 61
- 0
imperative/python/src/transformation.h View File

@@ -0,0 +1,61 @@
/**
* \file imperative/python/src/transformation.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/imperative/transformation.h"

namespace mgb::imperative::python {
struct TransformationManager {
enum Segment {
ModuleTrace,
Grad,
Scalar,
Trace,
Eval,
};

std::array<std::vector<std::shared_ptr<Transformation>>, 5> segments;

template <Segment segment>
void register_at(std::shared_ptr<Transformation> transformation) {
mgb_assert(segment < segments.size());
std::shared_ptr<Transformation> next;
for (size_t i = segment; i < segments.size(); ++i) {
if (!segments[i].empty()) {
next = segments[i].back();
break;
}
}
if (!next) {
transformation->register_at(Transformation::bottom());
} else {
transformation->register_at(next->pos());
}
segments[segment].push_back(transformation);
}

template <Segment segment>
void unregister(std::shared_ptr<Transformation> transformation) noexcept {
mgb_assert(segment < segments.size());
auto iter = std::find(
segments[segment].begin(), segments[segment].end(), transformation);
mgb_assert(iter != segments[segment].end());
transformation->unregister();
segments[segment].erase(iter);
}

static TransformationManager& get_instance() {
static TransformationManager sl_instance;
return sl_instance;
}
};
} // namespace mgb::imperative::python

Loading…
Cancel
Save