You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

uctc.cc 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. #include <pybind11/pybind11.h>
  2. #include <pybind11/stl.h>
  3. #include "math/arith.h"
  4. #include "operators/nn.h"
  5. #include "tensor/tensor.h"
  6. #include "operators/ops.h"
  7. #include "operators/autodiff.h"
  8. namespace py = pybind11;
  9. PYBIND11_MODULE(uctc, m) {
  10. py::module C = m.def_submodule("C", "C module");
  11. py::module arith = C.def_submodule("arith", "Arithmetic module");
  12. arith.def("sqrt", &arith::sqrt, "Square root function", py::arg("x") = 0.0);
  13. py::class_<tensor::Tensor, std::shared_ptr<tensor::Tensor>>(m, "Tensor")
  14. .def_readonly("shape", &tensor::Tensor::shape)
  15. .def_readonly("size", &tensor::Tensor::size)
  16. .def("data", &tensor::Tensor::get_data, "Get the data of the tensor", pybind11::return_value_policy::copy)
  17. .def("transpose", &tensor::Tensor::transpose, "Transpose the tensor", pybind11::return_value_policy::copy);
  18. py::module nn = m.def_submodule("nn", "Neural network module");
  19. py::class_<nn::Node, std::shared_ptr<nn::Node>>(nn, "Node")
  20. .def("data", &nn::Node::get_data, "Get the data of the node", pybind11::return_value_policy::copy)
  21. .def("tensor", &nn::Node::get_tensor, "Get the tensor of the node", pybind11::return_value_policy::automatic_reference);
  22. py::class_<nn::DataNode, nn::Node, std::shared_ptr<nn::DataNode>>(nn, "DataNode");
  23. py::class_<nn::Parameter, nn::DataNode, std::shared_ptr<nn::Parameter>>(nn, "Parameter")
  24. .def(pybind11::init<py::array_t<float>>(), "Create a parameter from an array.")
  25. .def("update", &nn::Parameter::update, "Update the parameter node", py::arg("grad") = nullptr, py::arg("learning_rate") = 0.001);
  26. py::class_<nn::Constant, nn::DataNode, std::shared_ptr<nn::Constant>>(nn, "Constant")
  27. .def(pybind11::init<py::array_t<float>>(), "Create a constant node from a numpy array");
  28. py::class_<nn::FunctionNode, nn::Node, std::shared_ptr<nn::FunctionNode>>(nn, "FunctionNode");
  29. py::class_<nn::Add, nn::FunctionNode, std::shared_ptr<nn::Add>>(nn, "Add")
  30. .def(py::init<std::shared_ptr<nn::Node>, std::shared_ptr<nn::Node>>(), "Create an add function node")
  31. .def("forward", &nn::Add::forward, "Forward function");
  32. py::class_<nn::AddBias, nn::FunctionNode, std::shared_ptr<nn::AddBias>>(nn, "AddBias")
  33. .def(py::init<std::shared_ptr<nn::Node>, std::shared_ptr<nn::Node>>(), "Create an add bias function node")
  34. .def("forward", &nn::AddBias::forward, "Forward function")
  35. .def("data", &nn::AddBias::get_data, "Get the data of the node", pybind11::return_value_policy::automatic_reference);
  36. py::class_<nn::Linear, nn::FunctionNode, std::shared_ptr<nn::Linear>>(nn, "Linear")
  37. .def(py::init<std::shared_ptr<nn::Node>, std::shared_ptr<nn::Node>>(), "Create a linear function node")
  38. .def("forward", &nn::Linear::forward, "Forward function");
  39. py::class_<nn::ReLU, nn::FunctionNode, std::shared_ptr<nn::ReLU>>(nn, "ReLU")
  40. .def(py::init<std::shared_ptr<nn::Node>>(), "Create a ReLU function node");
  41. py::class_<nn::Loss, nn::FunctionNode, std::shared_ptr<nn::Loss>>(nn, "Loss");
  42. py::class_<nn::SquareLoss, nn::Loss, std::shared_ptr<nn::SquareLoss>>(nn, "SquareLoss")
  43. .def(py::init<std::shared_ptr<nn::Node>, std::shared_ptr<nn::Node>>(), "Create a square loss function node");
  44. py::class_<nn::SoftmaxLoss, nn::Loss, std::shared_ptr<nn::SoftmaxLoss>>(nn, "SoftmaxLoss")
  45. .def(py::init<std::shared_ptr<nn::Node>, std::shared_ptr<nn::Node>>(), "Create a softmax loss function node");
  46. nn.def("log_softmax", &nn::log_softmax, "Log softmax function", py::arg("logits"));
  47. nn.def("gradients", &nn::gradients, "Calculate the gradients", py::arg("loss") = nullptr, py::arg("nodes") = std::vector<std::shared_ptr<nn::Node>>{});
  48. nn.def("pyarray_to_tensor", &tensor::pyarray_to_tensor, "Convert a numpy array to a tensor", py::arg("arr"));
  49. nn.def("argmax", &tensor::argmax, "Get a tensor's argmax", py::arg("tensor"), py::arg("axis"));
  50. nn.def("mean", &tensor::mean, "Get a tensor element's mean value", py::arg("tensor"));
  51. nn.def("exp", &tensor::exp, "Get exp of a tensor", py::arg("tensor"));
  52. // framework test
  53. py::module framework = m.def_submodule("framework", "Framework module");
  54. py::module basis = framework.def_submodule("basis", "Basic modules");
  55. // task 1
  56. basis.def("mul", &operators::mul<int>, "Multiply two integers", py::arg("a"), py::arg("b"));
  57. basis.def("id", &operators::id<int>, "Identity function", py::arg("a"));
  58. basis.def("add", &operators::add<int>, "Add two integers", py::arg("a"), py::arg("b"));
  59. basis.def("neg", &operators::neg<int>, "Negate an integer", py::arg("a"));
  60. basis.def("lt", &operators::lt<int>, "Less than operator", py::arg("a"), py::arg("b"));
  61. basis.def("eq", &operators::eq<int>, "Equal operator", py::arg("a"), py::arg("b"));
  62. basis.def("max", &operators::max<int>, "Max operator", py::arg("a"), py::arg("b"));
  63. // task 2
  64. basis.def("is_close", &operators::is_close, "Check if two floats are close", py::arg("x"), py::arg("y"));
  65. basis.def("sigmoid", &operators::sigmoid, "Sigmoid function", py::arg("x"));
  66. basis.def("relu", &operators::relu, "ReLU function", py::arg("x"));
  67. basis.def("inv", &operators::inv, "Inverse function", py::arg("x"));
  68. basis.def("inv_back", &operators::inv_back, "Inv back function", py::arg("x"), py::arg("d"));
  69. basis.def("relu_back", &operators::relu_back, "ReLU back function", py::arg("x"), py::arg("d"));
  70. // task 3
  71. basis.def("negList", &operators::negList, "Negate a list of integers", py::arg("lst"));
  72. // task 4, 5
  73. basis.def("addLists", &operators::addLists, "Add two lists of integers", py::arg("lst1"), py::arg("lst2"));
  74. // task 6
  75. basis.def("sumList", &operators::sumList, "Sum a list of integers", py::arg("lst"));
  76. // task 7
  77. basis.def("prodList", &operators::prodList, "Multiply a list of integers", py::arg("lst"));
  78. py::module autodiff = framework.def_submodule("autodiff", "Autodiff modules");
  79. autodiff.def("test_central_difference", &autodiff::test_central_difference, "Test central difference");
  80. autodiff.def("test_addscalar", &autodiff::test_addscalar, "Test add scalar");
  81. autodiff.def("test_mulscalar", &autodiff::test_mulscalar, "Test mul scalar");
  82. autodiff.def("test_logscalar", &autodiff::test_logscalar, "Test log scalar");
  83. autodiff.def("test_invscalar", &autodiff::test_invscalar, "Test inv scalar");
  84. autodiff.def("test_sigmoidscalar", &autodiff::test_sigmoidscalar, "Test sigmoid scalar");
  85. }

计算机大作业