You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

param_pack.cpp 1.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. #include "megdnn/oprs/general.h"
  2. #include "src/common/utils.h"
  3. using namespace megdnn;
  // Validate the layouts handed to a param-pack concat/split execution.
  //
  // concated: layout of the packed tensor.
  // offsets:  layout of the offsets tensor; its dtype must be Int32.
  // parts:    layout of the parts tensor.
  //
  // Both checks go through megdnn_assert; the second one reports the string
  // form of all three layouts when it fails.
  4. void ParamPackConcatSplitBase::check_exec(
  5. const TensorLayout& concated, const TensorLayout& offsets,
  6. const TensorLayout& parts) {
  // The offsets tensor has to hold 32-bit signed integers.
  7. megdnn_assert(
  8. offsets.dtype == dtype::Int32{}, "bad dtype: %s", offsets.dtype.name());
  // Every layout must be one-dimensional with a stride of 1 on its only axis.
  9. megdnn_assert(
  10. concated.ndim == 1 && offsets.ndim == 1 && parts.ndim == 1 &&
  11. concated.stride[0] == 1 && offsets.stride[0] == 1 &&
  12. parts.stride[0] == 1,
  13. "bad layout: concated=%s offsets=%s parts=%s", concated.to_string().c_str(),
  14. offsets.to_string().c_str(), parts.to_string().c_str());
  15. }
  16. std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets(
  17. const TensorShapeArray& shapes, size_t alignment, size_t dtype_size) {
  18. megdnn_assert(
  19. alignment && (alignment & (alignment - 1)) == 0,
  20. "alignment must be power of 2: %zu", alignment);
  21. if (alignment < dtype_size)
  22. alignment = dtype_size;
  23. megdnn_assert(
  24. alignment % dtype_size == 0,
  25. "alignment must be multiple of dtype size: %zu vs %zu", alignment,
  26. dtype_size);
  27. alignment /= dtype_size;
  28. auto get_aligned = [alignment](size_t v) {
  29. auto mod = v & (alignment - 1);
  30. return v + ((alignment - mod) & (alignment - 1));
  31. };
  32. std::vector<dt_int32> offsets(shapes.size() << 1);
  33. size_t offset = 0;
  34. for (size_t i = 0; i < shapes.size(); i++) {
  35. offset = get_aligned(offset);
  36. offsets[i << 1] = offset;
  37. offset += shapes[i].total_nr_elems();
  38. offsets[(i << 1) + 1] = offset;
  39. }
  40. return offsets;
  41. }
  42. // vim: syntax=cpp.doxygen