Skip to content

Commit

Permalink
Merge pull request #73 from madmann91/generic_traversal
Browse files Browse the repository at this point in the history
Add generic traversal function
  • Loading branch information
madmann91 authored Feb 17, 2024
2 parents a3a9fe9 + ed03eef commit 921e277
Showing 1 changed file with 41 additions and 24 deletions.
65 changes: 41 additions & 24 deletions src/bvh/v2/bvh.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ template <typename Node>
struct Bvh {
using Index = typename Node::Index;
using Scalar = typename Node::Scalar;
using Ray = bvh::v2::Ray<Scalar, Node::dimension>;

std::vector<Node> nodes;
std::vector<size_t> prim_ids;
Expand All @@ -33,13 +34,21 @@ struct Bvh {
/// Extracts the BVH rooted at the given node index.
inline Bvh extract_bvh(size_t root_id) const;

/// Traverses the BVH from the given index in `start` using the provided stack. Every leaf
/// encountered on the way is processed using the given `LeafFn` function, and every pair of
/// nodes is processed with the function in `HitFn`, which returns a triplet of booleans
/// indicating whether the first child should be processed, whether the second child should be
/// processed, and whether to traverse the second child first instead of the other way around.
template <bool IsAnyHit, typename Stack, typename LeafFn, typename InnerFn>
inline void traverse(Index start, Stack&, LeafFn&&, InnerFn&&) const;

/// Intersects the BVH with a single ray, using the given function to intersect the contents
/// of a leaf. The algorithm starts at the node index `top` and uses the given stack object.
/// of a leaf. The algorithm starts at the node index `start` and uses the given stack object.
/// When `IsAnyHit` is true, the function stops at the first intersection (useful for shadow
/// rays), otherwise it finds the closest intersection. When `IsRobust` is true, a slower but
/// numerically robust ray-box test is used, otherwise a fast, but less precise test is used.
template <bool IsAnyHit, bool IsRobust, typename Stack, typename LeafFn, typename InnerFn = IgnoreArgs>
inline void intersect(Ray<Scalar, Node::dimension>& ray, Index top, Stack&, LeafFn&&, InnerFn&& = {}) const;
inline void intersect(const Ray& ray, Index start, Stack&, LeafFn&&, InnerFn&& = {}) const;

inline void serialize(OutputStream&) const;
static inline Bvh deserialize(InputStream&);
Expand Down Expand Up @@ -79,40 +88,23 @@ auto Bvh<Node>::extract_bvh(size_t root_id) const -> Bvh {
}

template <typename Node>
template <bool IsAnyHit, bool IsRobust, typename Stack, typename LeafFn, typename InnerFn>
void Bvh<Node>::intersect(Ray<Scalar, Node::dimension>& ray, Index start, Stack& stack, LeafFn&& leaf_fn, InnerFn&& inner_fn) const {
auto inv_dir = ray.template get_inv_dir<!IsRobust>();
auto inv_org = -inv_dir * ray.org;
auto inv_dir_pad = Ray<Scalar, Node::dimension>::pad_inv_dir(inv_dir);
auto octant = ray.get_octant();

auto intersect_node = [&] (const Node& node) {
return IsRobust
? node.intersect_robust(ray, inv_dir, inv_dir_pad, octant)
: node.intersect_fast(ray, inv_dir, inv_org, octant);
};

template <bool IsAnyHit, typename Stack, typename LeafFn, typename InnerFn>
void Bvh<Node>::traverse(Index start, Stack& stack, LeafFn&& leaf_fn, InnerFn&& inner_fn) const
{
stack.push(start);
restart:
while (!stack.is_empty()) {
auto top = stack.pop();
while (top.prim_count == 0) {
auto& left = nodes[top.first_id];
auto& right = nodes[top.first_id + 1];

inner_fn(left, right);

auto intr_left = intersect_node(left);
auto intr_right = intersect_node(right);

bool hit_left = intr_left.first <= intr_left.second;
bool hit_right = intr_right.first <= intr_right.second;
auto [hit_left, hit_right, should_swap] = inner_fn(left, right);

if (hit_left) {
auto near_index = left.index;
if (hit_right) {
auto far_index = right.index;
if (!IsAnyHit && intr_left.first > intr_right.first)
if (should_swap)
std::swap(near_index, far_index);
stack.push(far_index);
}
Expand All @@ -130,6 +122,31 @@ void Bvh<Node>::intersect(Ray<Scalar, Node::dimension>& ray, Index start, Stack&
}
}

template <typename Node>
template <bool IsAnyHit, bool IsRobust, typename Stack, typename LeafFn, typename InnerFn>
void Bvh<Node>::intersect(const Ray& ray, Index start, Stack& stack, LeafFn&& leaf_fn, InnerFn&& inner_fn) const {
auto inv_dir = ray.template get_inv_dir<!IsRobust>();
auto inv_org = -inv_dir * ray.org;
auto inv_dir_pad = ray.pad_inv_dir(inv_dir);
auto octant = ray.get_octant();

traverse<IsAnyHit>(start, stack, leaf_fn, [&] (const Node& left, const Node& right) {
inner_fn(left, right);
std::pair<Scalar, Scalar> intr_left, intr_right;
if constexpr (IsRobust) {
intr_left = left.intersect_robust(ray, inv_dir, inv_dir_pad, octant);
intr_right = right.intersect_robust(ray, inv_dir, inv_org, octant);
} else {
intr_left = left.intersect_fast(ray, inv_dir, inv_org, octant);
intr_right = right.intersect_fast(ray, inv_dir, inv_org, octant);
}
return std::make_tuple(
intr_left.first <= intr_left.second,
intr_right.first <= intr_right.second,
!IsAnyHit && intr_left.first > intr_right.first);
});
}

template <typename Node>
void Bvh<Node>::serialize(OutputStream& stream) const {
stream.write(nodes.size());
Expand Down

0 comments on commit 921e277

Please sign in to comment.