|
|
|
@@ -150,6 +150,7 @@ class BroadcastIterator { |
|
|
|
public:
|
|
|
|
BroadcastIterator(std::vector<size_t> input_shape_a, std::vector<size_t> input_shape_b,
|
|
|
|
std::vector<size_t> output_shape);
|
|
|
|
virtual ~BroadcastIterator() = default;
|
|
|
|
inline size_t GetInputPosA() const { return input_pos_[0]; }
|
|
|
|
inline size_t GetInputPosB() const { return input_pos_[1]; }
|
|
|
|
void SetPos(size_t pos);
|
|
|
|
@@ -174,6 +175,7 @@ class BroadcastIterator { |
|
|
|
class TransposeIterator {
|
|
|
|
public:
|
|
|
|
TransposeIterator(std::vector<size_t> output_shape, std::vector<size_t> axes, const std::vector<size_t> &input_shape);
|
|
|
|
virtual ~TransposeIterator() = default;
|
|
|
|
inline size_t GetPos() const { return pos_; }
|
|
|
|
void SetPos(size_t pos);
|
|
|
|
void GenNextPos();
|
|
|
|
|