Skip to content

Commit

Permalink
Draft support for multi-threading
Browse files Browse the repository at this point in the history
  • Loading branch information
victorreijgwart committed Oct 24, 2024
1 parent c352f73 commit fea5914
Showing 1 changed file with 38 additions and 16 deletions.
54 changes: 38 additions & 16 deletions library/python/src/raycast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
#include <wavemap/core/map/hashed_wavelet_octree.h>
#include <wavemap/core/utils/iterate/grid_iterator.h>
#include <wavemap/core/utils/query/query_accelerator.h>
#include <wavemap/core/utils/thread_pool.h>

#include "wavemap/core/utils/iterate/ray_iterator.h"

using namespace nb::literals; // NOLINT

namespace wavemap {
FloatingPoint raycast(const HashedWaveletOctree& map, Point3D start_point,
Point3D end_point, FloatingPoint threshold) {
FloatingPoint raycast(const HashedWaveletOctree& map,
const Point3D& start_point, const Point3D& end_point,
FloatingPoint threshold) {
const FloatingPoint mcw = map.getMinCellWidth();
const Ray ray(start_point, end_point, mcw);
for (const Index3D& ray_voxel_index : ray) {
Expand All @@ -29,8 +31,8 @@ FloatingPoint raycast(const HashedWaveletOctree& map, Point3D start_point,

FloatingPoint raycast_fast(
QueryAccelerator<HashedWaveletOctree>& query_accelerator,
Point3D start_point, Point3D end_point, FloatingPoint threshold,
FloatingPoint min_cell_width) {
const Point3D& start_point, const Point3D& end_point,
FloatingPoint threshold, FloatingPoint min_cell_width) {
const Ray ray(start_point, end_point, min_cell_width);
for (const Index3D& ray_voxel_index : ray) {
if (query_accelerator.getCellValue(ray_voxel_index) > threshold) {
Expand Down Expand Up @@ -72,25 +74,45 @@ void add_raycast_bindings(nb::module_& m) {

m.def(
"get_depth",
[](const HashedWaveletOctree& map, Transformation3D pose,
PinholeCameraProjectorConfig cam_cfg, FloatingPoint threshold,
[](const HashedWaveletOctree& map, const Transformation3D& pose,
const PinholeCameraProjectorConfig& cam_cfg, FloatingPoint threshold,
FloatingPoint max_range) {
// NOTE: This way of parallelizing it is not very efficient, as it
// creates a very large number of jobs (1 per pixel), but already
// leads to a nice speedup. The next step to improve it would
// probably be to split the image into tiles and spawn 1 job per
// tile. The tile size should be such that there are enough tiles
// to distribute the work evenly across all cores even if some
// tiles take much less time than others, while still being few
// enough to minimize the overhead of dispatching jobs and creating
// local QueryAccelerator instances for each job. Maybe 10x as
// many tiles as there are workers?
ThreadPool thread_pool; // By default, the pool will spawn as many
// workers as the system's reported
// std::thread::hardware_concurrency().
Image depth_image(cam_cfg.width, cam_cfg.height);
QueryAccelerator query_accelerator(map);
const FloatingPoint mcw = map.getMinCellWidth();
const FloatingPoint min_cell_width = map.getMinCellWidth();
const PinholeCameraProjector projection_model(cam_cfg);
auto start_point = pose.getPosition();
const Point3D& start_point = pose.getPosition();
for (const Index2D& index :
Grid<2>(Index2D::Zero(),
depth_image.getDimensions() - Index2D::Ones())) {
const Vector2D image_xy = projection_model.indexToImage(index);
const Point3D C_point =
projection_model.sensorToCartesian({image_xy, max_range});
const Point3D end_point = pose * C_point;
FloatingPoint depth = raycast_fast(query_accelerator, start_point,
end_point, threshold, mcw);
depth_image.at(index) = depth;
FloatingPoint& depth_pixel = depth_image.at(index);
thread_pool.add_task([&map, &projection_model, &pose, &start_point,
&depth_pixel, index, max_range, threshold,
min_cell_width]() {
QueryAccelerator query_accelerator(map);
const Vector2D image_xy = projection_model.indexToImage(index);
const Point3D C_point =
projection_model.sensorToCartesian({image_xy, max_range});
const Point3D end_point = pose * C_point;
FloatingPoint depth =
raycast_fast(query_accelerator, start_point, end_point,
threshold, min_cell_width);
depth_pixel = depth;
});
}
thread_pool.wait_all();
return depth_image.getData().transpose().eval();
},
"Extract depth from octree map at using given camera pose and "
Expand Down

0 comments on commit fea5914

Please sign in to comment.