Skip to content
117 changes: 76 additions & 41 deletions crates/bevy_solari/src/realtime/restir_di.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#import bevy_solari::brdf::evaluate_brdf
#import bevy_solari::gbuffer_utils::{gpixel_resolve, pixel_dissimilar, permute_pixel}
#import bevy_solari::presample_light_tiles::{ResolvedLightSamplePacked, unpack_resolved_light_sample}
#import bevy_solari::sampling::{LightSample, calculate_resolved_light_contribution, resolve_and_calculate_light_contribution, resolve_light_sample, trace_light_visibility}
#import bevy_solari::sampling::{LightSample, calculate_resolved_light_contribution, resolve_and_calculate_light_contribution, resolve_light_sample, trace_light_visibility, balance_heuristic}
#import bevy_solari::scene_bindings::{light_sources, previous_frame_light_id_translations, LIGHT_NOT_PRESENT_THIS_FRAME}

@group(1) @binding(0) var view_output: texture_storage_2d<rgba16float, read_write>;
Expand Down Expand Up @@ -49,8 +49,9 @@ fn initial_and_temporal(@builtin(workgroup_id) workgroup_id: vec3<u32>, @builtin

let diffuse_brdf = surface.material.base_color / PI;
let initial_reservoir = generate_initial_reservoir(surface.world_position, surface.world_normal, diffuse_brdf, workgroup_id.xy, &rng);
let temporal_reservoir = load_temporal_reservoir(global_id.xy, depth, surface.world_position, surface.world_normal);
let merge_result = merge_reservoirs(initial_reservoir, temporal_reservoir, surface.world_position, surface.world_normal, diffuse_brdf, &rng);
let temporal = load_temporal_reservoir(global_id.xy, depth, surface.world_position, surface.world_normal);
let merge_result = merge_reservoirs(initial_reservoir, surface.world_position, surface.world_normal, diffuse_brdf,
temporal.reservoir, temporal.world_position, temporal.world_normal, temporal.diffuse_brdf, &rng);

store_reservoir_b(global_id.xy, merge_result.merged_reservoir);
}
Expand All @@ -71,16 +72,25 @@ fn spatial_and_shade(@builtin(global_invocation_id) global_id: vec3<u32>) {

let diffuse_brdf = surface.material.base_color / PI;
let input_reservoir = load_reservoir_b(global_id.xy);
let spatial_reservoir = load_spatial_reservoir(global_id.xy, depth, surface.world_position, surface.world_normal, &rng);
let merge_result = merge_reservoirs(input_reservoir, spatial_reservoir, surface.world_position, surface.world_normal, diffuse_brdf, &rng);
let spatial = load_spatial_reservoir(global_id.xy, depth, surface.world_position, surface.world_normal, &rng);
let merge_result = merge_reservoirs(input_reservoir, surface.world_position, surface.world_normal, diffuse_brdf,
spatial.reservoir, spatial.world_position, spatial.world_normal, spatial.diffuse_brdf, &rng);
var combined_reservoir = merge_result.merged_reservoir;

// More accuracy, less stability
#ifndef BIASED_RESAMPLING
store_reservoir_a(global_id.xy, combined_reservoir);
#endif

if reservoir_valid(combined_reservoir) {
let resolved_light_sample = resolve_light_sample(combined_reservoir.sample, light_sources[combined_reservoir.sample.light_id >> 16u]);
combined_reservoir.unbiased_contribution_weight *= trace_light_visibility(surface.world_position, resolved_light_sample.world_position);
}

// More stability, less accuracy (shadows extend further out than they should)
#ifdef BIASED_RESAMPLING
store_reservoir_a(global_id.xy, combined_reservoir);
#endif

let wo = normalize(view.world_position - surface.world_position);
let brdf = evaluate_brdf(surface.world_normal, wo, merge_result.wi, surface.material);
Expand Down Expand Up @@ -135,63 +145,67 @@ fn generate_initial_reservoir(world_position: vec3<f32>, world_normal: vec3<f32>
return reservoir;
}

fn load_temporal_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>) -> Reservoir {
fn load_temporal_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>) -> NeighborInfo {
let motion_vector = textureLoad(motion_vectors, pixel_id, 0).xy;
let temporal_pixel_id_float = round(vec2<f32>(pixel_id) - (motion_vector * view.main_pass_viewport.zw));

// Check if the current pixel was off screen during the previous frame (current pixel is newly visible),
// or if all temporal history should assumed to be invalid
if any(temporal_pixel_id_float < vec2(0.0)) || any(temporal_pixel_id_float >= view.main_pass_viewport.zw) || bool(constants.reset) {
return empty_reservoir();
return NeighborInfo(empty_reservoir(), vec3(0.0), vec3(0.0), vec3(0.0));
}

let permuted_temporal_pixel_id = permute_pixel(vec2<u32>(temporal_pixel_id_float), constants.frame_index, view.viewport.zw);
var temporal_reservoir = load_temporal_reservoir_inner(permuted_temporal_pixel_id, depth, world_position, world_normal);
var temporal = load_temporal_reservoir_inner(permuted_temporal_pixel_id, depth, world_position, world_normal);

// If permuted reprojection failed (tends to happen on object edges), try point reprojection
if !reservoir_valid(temporal_reservoir) {
temporal_reservoir = load_temporal_reservoir_inner(vec2<u32>(temporal_pixel_id_float), depth, world_position, world_normal);
if !reservoir_valid(temporal.reservoir) {
temporal = load_temporal_reservoir_inner(vec2<u32>(temporal_pixel_id_float), depth, world_position, world_normal);
}

// Check if the light selected in the previous frame no longer exists in the current frame (e.g. entity despawned)
let previous_light_id = temporal_reservoir.sample.light_id >> 16u;
let triangle_id = temporal_reservoir.sample.light_id & 0xFFFFu;
let previous_light_id = temporal.reservoir.sample.light_id >> 16u;
let triangle_id = temporal.reservoir.sample.light_id & 0xFFFFu;
let light_id = previous_frame_light_id_translations[previous_light_id];
if light_id == LIGHT_NOT_PRESENT_THIS_FRAME {
return empty_reservoir();
return NeighborInfo(empty_reservoir(), vec3(0.0), vec3(0.0), vec3(0.0));
}
temporal_reservoir.sample.light_id = (light_id << 16u) | triangle_id;
temporal.reservoir.sample.light_id = (light_id << 16u) | triangle_id;

temporal_reservoir.confidence_weight = min(temporal_reservoir.confidence_weight, CONFIDENCE_WEIGHT_CAP);
temporal.reservoir.confidence_weight = min(temporal.reservoir.confidence_weight, CONFIDENCE_WEIGHT_CAP);

return temporal_reservoir;
return temporal;
}

fn load_temporal_reservoir_inner(temporal_pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>) -> Reservoir {
fn load_temporal_reservoir_inner(temporal_pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>) -> NeighborInfo {
// Check if the pixel features have changed heavily between the current and previous frame
let temporal_depth = textureLoad(previous_depth_buffer, temporal_pixel_id, 0);
let temporal_surface = gpixel_resolve(textureLoad(previous_gbuffer, temporal_pixel_id, 0), temporal_depth, temporal_pixel_id, view.main_pass_viewport.zw, previous_view.world_from_clip);
let temporal_diffuse_brdf = temporal_surface.material.base_color / PI;
if pixel_dissimilar(depth, world_position, temporal_surface.world_position, world_normal, temporal_surface.world_normal, view) {
return empty_reservoir();
return NeighborInfo(empty_reservoir(), vec3(0.0), vec3(0.0), vec3(0.0));
}

return load_reservoir_a(temporal_pixel_id);
let temporal_reservoir = load_reservoir_a(temporal_pixel_id);
return NeighborInfo(temporal_reservoir, temporal_surface.world_position, temporal_surface.world_normal, temporal_diffuse_brdf);
}

fn load_spatial_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> Reservoir {
fn load_spatial_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> NeighborInfo {
for (var i = 0u; i < 5u; i++) {
let spatial_pixel_id = get_neighbor_pixel_id(pixel_id, rng);

let spatial_depth = textureLoad(depth_buffer, spatial_pixel_id, 0);
let spatial_surface = gpixel_resolve(textureLoad(gbuffer, spatial_pixel_id, 0), spatial_depth, spatial_pixel_id, view.main_pass_viewport.zw, view.world_from_clip);
let spatial_diffuse_brdf = spatial_surface.material.base_color / PI;
if pixel_dissimilar(depth, world_position, spatial_surface.world_position, world_normal, spatial_surface.world_normal, view) {
continue;
}

return load_reservoir_b(spatial_pixel_id);
let spatial_reservoir = load_reservoir_b(spatial_pixel_id);
return NeighborInfo(spatial_reservoir, spatial_surface.world_position, spatial_surface.world_normal, spatial_diffuse_brdf);
}

return empty_reservoir();
return NeighborInfo(empty_reservoir(), world_position, world_normal, vec3(0.0));
}

fn get_neighbor_pixel_id(center_pixel_id: vec2<u32>, rng: ptr<function, u32>) -> vec2<u32> {
Expand All @@ -200,6 +214,13 @@ fn get_neighbor_pixel_id(center_pixel_id: vec2<u32>, rng: ptr<function, u32>) ->
return vec2<u32>(spatial_id);
}

struct NeighborInfo {
reservoir: Reservoir,
world_position: vec3<f32>,
world_normal: vec3<f32>,
diffuse_brdf: vec3<f32>,
}

struct Reservoir {
sample: LightSample,
confidence_weight: f32,
Expand Down Expand Up @@ -252,42 +273,56 @@ struct ReservoirMergeResult {

fn merge_reservoirs(
canonical_reservoir: Reservoir,
canonical_world_position: vec3<f32>,
canonical_world_normal: vec3<f32>,
canonical_diffuse_brdf: vec3<f32>,
other_reservoir: Reservoir,
world_position: vec3<f32>,
world_normal: vec3<f32>,
diffuse_brdf: vec3<f32>,
other_world_position: vec3<f32>,
other_world_normal: vec3<f32>,
other_diffuse_brdf: vec3<f32>,
rng: ptr<function, u32>,
) -> ReservoirMergeResult {
let canonical_contribution = reservoir_contribution(canonical_reservoir, world_position, world_normal, diffuse_brdf);
let other_contribution = reservoir_contribution(other_reservoir, world_position, world_normal, diffuse_brdf);

let mis_weight_denominator = 1.0 / (canonical_reservoir.confidence_weight + other_reservoir.confidence_weight);

let canonical_mis_weight = canonical_reservoir.confidence_weight * mis_weight_denominator;
let canonical_resampling_weight = canonical_mis_weight * (canonical_contribution.target_function * canonical_reservoir.unbiased_contribution_weight);

let other_mis_weight = other_reservoir.confidence_weight * mis_weight_denominator;
let other_resampling_weight = other_mis_weight * (other_contribution.target_function * other_reservoir.unbiased_contribution_weight);
// Contributions for resampling
let canonical_contribution_canonical_sample = reservoir_contribution(canonical_reservoir, canonical_world_position, canonical_world_normal, canonical_diffuse_brdf);
let canonical_contribution_other_sample = reservoir_contribution(other_reservoir, canonical_world_position, canonical_world_normal, canonical_diffuse_brdf);

// Extra contributions for MIS
let other_contribution_canonical_sample = reservoir_contribution(canonical_reservoir, other_world_position, other_world_normal, other_diffuse_brdf);
let other_contribution_other_sample = reservoir_contribution(other_reservoir, other_world_position, other_world_normal, other_diffuse_brdf);

// Resampling weight for canonical sample
let canonical_sample_mis_weight = balance_heuristic(
canonical_reservoir.confidence_weight * canonical_contribution_canonical_sample.target_function,
other_reservoir.confidence_weight * other_contribution_canonical_sample.target_function,
);
let canonical_sample_resampling_weight = canonical_sample_mis_weight * canonical_contribution_canonical_sample.target_function * canonical_reservoir.unbiased_contribution_weight;

let weight_sum = canonical_resampling_weight + other_resampling_weight;
// Resampling weight for other sample
let other_sample_mis_weight = balance_heuristic(
other_reservoir.confidence_weight * other_contribution_other_sample.target_function,
canonical_reservoir.confidence_weight * canonical_contribution_other_sample.target_function,
);
let other_sample_resampling_weight = other_sample_mis_weight * canonical_contribution_other_sample.target_function * other_reservoir.unbiased_contribution_weight;

// Perform resampling
var combined_reservoir = empty_reservoir();
combined_reservoir.confidence_weight = canonical_reservoir.confidence_weight + other_reservoir.confidence_weight;
let weight_sum = canonical_sample_resampling_weight + other_sample_resampling_weight;

if rand_f(rng) < other_resampling_weight / weight_sum {
if rand_f(rng) < other_sample_resampling_weight / weight_sum {
combined_reservoir.sample = other_reservoir.sample;

let inverse_target_function = select(0.0, 1.0 / other_contribution.target_function, other_contribution.target_function > 0.0);
let inverse_target_function = select(0.0, 1.0 / canonical_contribution_other_sample.target_function, canonical_contribution_other_sample.target_function > 0.0);
combined_reservoir.unbiased_contribution_weight = weight_sum * inverse_target_function;

return ReservoirMergeResult(combined_reservoir, other_contribution.radiance, other_contribution.wi);
return ReservoirMergeResult(combined_reservoir, canonical_contribution_other_sample.radiance, canonical_contribution_other_sample.wi);
} else {
combined_reservoir.sample = canonical_reservoir.sample;

let inverse_target_function = select(0.0, 1.0 / canonical_contribution.target_function, canonical_contribution.target_function > 0.0);
let inverse_target_function = select(0.0, 1.0 / canonical_contribution_canonical_sample.target_function, canonical_contribution_canonical_sample.target_function > 0.0);
combined_reservoir.unbiased_contribution_weight = weight_sum * inverse_target_function;

return ReservoirMergeResult(combined_reservoir, canonical_contribution.radiance, canonical_contribution.wi);
return ReservoirMergeResult(combined_reservoir, canonical_contribution_canonical_sample.radiance, canonical_contribution_canonical_sample.wi);
}
}

Expand Down
10 changes: 1 addition & 9 deletions crates/bevy_solari/src/realtime/restir_gi.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#import bevy_render::view::View
#import bevy_solari::brdf::evaluate_diffuse_brdf
#import bevy_solari::gbuffer_utils::{gpixel_resolve, pixel_dissimilar, permute_pixel}
#import bevy_solari::sampling::{sample_random_light, trace_point_visibility}
#import bevy_solari::sampling::{sample_random_light, trace_point_visibility, balance_heuristic}
#import bevy_solari::scene_bindings::{trace_ray, resolve_ray_hit_full, RAY_T_MIN, RAY_T_MAX}
#import bevy_solari::world_cache::{query_world_cache, WORLD_CACHE_CELL_LIFETIME}

Expand Down Expand Up @@ -321,11 +321,3 @@ fn merge_reservoirs(
return ReservoirMergeResult(combined_reservoir, canonical_sample_radiance);
}
}

fn balance_heuristic(x: f32, y: f32) -> f32 {
let sum = x + y;
if sum == 0.0 {
return 0.0;
}
return x / sum;
}
2 changes: 1 addition & 1 deletion crates/bevy_solari/src/realtime/specular_gi.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ fn nee_mis_weight(inverse_p_light: f32, brdf_rays_can_hit: bool, wo_tangent: vec

let p_light = 1.0 / inverse_p_light;
let p_bounce = ggx_vndf_pdf(wo_tangent, wi_tangent, ray_hit.material.roughness);
return max(0.0, power_heuristic(p_light, p_bounce));
return power_heuristic(p_light, p_bounce);
}

// Don't adjust the size of this struct without also adjusting GI_RESERVOIR_STRUCT_SIZE.
Expand Down
13 changes: 8 additions & 5 deletions crates/bevy_solari/src/realtime/world_cache_query.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,19 @@ struct WorldCacheGeometryData {
@group(1) @binding(23) var<storage, read_write> world_cache_active_cells_count: u32;

#ifndef WORLD_CACHE_NON_ATOMIC_LIFE_BUFFER
fn query_world_cache(world_position: vec3<f32>, world_normal: vec3<f32>, view_position: vec3<f32>, cell_lifetime: u32, rng: ptr<function, u32>) -> vec3<f32> {
fn query_world_cache(world_position_in: vec3<f32>, world_normal: vec3<f32>, view_position: vec3<f32>, cell_lifetime: u32, rng: ptr<function, u32>) -> vec3<f32> {
var world_position = world_position_in;
var cell_size = get_cell_size(world_position, view_position);

// https://tomclabault.github.io/blog/2025/regir, jitter_world_position_tangent_plane
#ifndef NO_JITTER_WORLD_CACHE
let TBN = orthonormalize(world_normal);
let offset = (rand_vec2f(rng) * 2.0 - 1.0) * cell_size * 0.5;
let jittered_position = world_position + offset.x * TBN[0] + offset.y * TBN[1];
cell_size = get_cell_size(jittered_position, view_position);
world_position += offset.x * TBN[0] + offset.y * TBN[1];
cell_size = get_cell_size(world_position, view_position);
#endif

let world_position_quantized = bitcast<vec3<u32>>(quantize_position(jittered_position, cell_size));
let world_position_quantized = bitcast<vec3<u32>>(quantize_position(world_position, cell_size));
let world_normal_quantized = bitcast<vec3<u32>>(quantize_normal(world_normal));
var key = compute_key(world_position_quantized, world_normal_quantized);
let checksum = compute_checksum(world_position_quantized, world_normal_quantized);
Expand All @@ -77,7 +80,7 @@ fn query_world_cache(world_position: vec3<f32>, world_normal: vec3<f32>, view_po
return world_cache_radiance[key].rgb;
} else if existing_checksum == WORLD_CACHE_EMPTY_CELL {
// Cell is empty - initialize it
world_cache_geometry_data[key].world_position = jittered_position;
world_cache_geometry_data[key].world_position = world_position;
world_cache_geometry_data[key].world_normal = world_normal;
return vec3(0.0);
} else {
Expand Down
8 changes: 6 additions & 2 deletions crates/bevy_solari/src/scene/sampling.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
#import bevy_solari::scene_bindings::{trace_ray, RAY_T_MIN, RAY_T_MAX, light_sources, directional_lights, LightSource, LIGHT_SOURCE_KIND_DIRECTIONAL, resolve_triangle_data_full, ResolvedRayHitFull}

fn power_heuristic(f: f32, g: f32) -> f32 {
return f * f / (f * f + g * g);
return balance_heuristic(f * f, g * g);
}

fn balance_heuristic(f: f32, g: f32) -> f32 {
return f / (f + g);
let sum = f + g;
if sum == 0.0 {
return 0.0;
}
return max(0.0, f / sum);
}

// https://gpuopen.com/download/Bounded_VNDF_Sampling_for_Smith-GGX_Reflections.pdf (Listing 1)
Expand Down