Skip to content

Commit f860ee7

Browse files
authored
Solari: More accurate ReSTIR DI resampling (#22033)
* Adds MIS during reservoir merge which increases emissive light brightness, bringing it a little closer to the pathtraced reference * Tests final reservoir visibility for shading, but _not_ for resampling, which greatly improves shadow accuracy compared to the pathtraced reference, at the unfortunate cost of higher noise
1 parent 185712f commit f860ee7

File tree

5 files changed

+92
-58
lines changed

5 files changed

+92
-58
lines changed

crates/bevy_solari/src/realtime/restir_di.wgsl

Lines changed: 76 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#import bevy_solari::brdf::evaluate_brdf
1010
#import bevy_solari::gbuffer_utils::{gpixel_resolve, pixel_dissimilar, permute_pixel}
1111
#import bevy_solari::presample_light_tiles::{ResolvedLightSamplePacked, unpack_resolved_light_sample}
12-
#import bevy_solari::sampling::{LightSample, calculate_resolved_light_contribution, resolve_and_calculate_light_contribution, resolve_light_sample, trace_light_visibility}
12+
#import bevy_solari::sampling::{LightSample, calculate_resolved_light_contribution, resolve_and_calculate_light_contribution, resolve_light_sample, trace_light_visibility, balance_heuristic}
1313
#import bevy_solari::scene_bindings::{light_sources, previous_frame_light_id_translations, LIGHT_NOT_PRESENT_THIS_FRAME}
1414

1515
@group(1) @binding(0) var view_output: texture_storage_2d<rgba16float, read_write>;
@@ -49,8 +49,9 @@ fn initial_and_temporal(@builtin(workgroup_id) workgroup_id: vec3<u32>, @builtin
4949

5050
let diffuse_brdf = surface.material.base_color / PI;
5151
let initial_reservoir = generate_initial_reservoir(surface.world_position, surface.world_normal, diffuse_brdf, workgroup_id.xy, &rng);
52-
let temporal_reservoir = load_temporal_reservoir(global_id.xy, depth, surface.world_position, surface.world_normal);
53-
let merge_result = merge_reservoirs(initial_reservoir, temporal_reservoir, surface.world_position, surface.world_normal, diffuse_brdf, &rng);
52+
let temporal = load_temporal_reservoir(global_id.xy, depth, surface.world_position, surface.world_normal);
53+
let merge_result = merge_reservoirs(initial_reservoir, surface.world_position, surface.world_normal, diffuse_brdf,
54+
temporal.reservoir, temporal.world_position, temporal.world_normal, temporal.diffuse_brdf, &rng);
5455

5556
store_reservoir_b(global_id.xy, merge_result.merged_reservoir);
5657
}
@@ -71,16 +72,25 @@ fn spatial_and_shade(@builtin(global_invocation_id) global_id: vec3<u32>) {
7172

7273
let diffuse_brdf = surface.material.base_color / PI;
7374
let input_reservoir = load_reservoir_b(global_id.xy);
74-
let spatial_reservoir = load_spatial_reservoir(global_id.xy, depth, surface.world_position, surface.world_normal, &rng);
75-
let merge_result = merge_reservoirs(input_reservoir, spatial_reservoir, surface.world_position, surface.world_normal, diffuse_brdf, &rng);
75+
let spatial = load_spatial_reservoir(global_id.xy, depth, surface.world_position, surface.world_normal, &rng);
76+
let merge_result = merge_reservoirs(input_reservoir, surface.world_position, surface.world_normal, diffuse_brdf,
77+
spatial.reservoir, spatial.world_position, spatial.world_normal, spatial.diffuse_brdf, &rng);
7678
var combined_reservoir = merge_result.merged_reservoir;
7779

80+
// More accuracy, less stability
81+
#ifndef BIASED_RESAMPLING
82+
store_reservoir_a(global_id.xy, combined_reservoir);
83+
#endif
84+
7885
if reservoir_valid(combined_reservoir) {
7986
let resolved_light_sample = resolve_light_sample(combined_reservoir.sample, light_sources[combined_reservoir.sample.light_id >> 16u]);
8087
combined_reservoir.unbiased_contribution_weight *= trace_light_visibility(surface.world_position, resolved_light_sample.world_position);
8188
}
8289

90+
// More stability, less accuracy (shadows extend further out than they should)
91+
#ifdef BIASED_RESAMPLING
8392
store_reservoir_a(global_id.xy, combined_reservoir);
93+
#endif
8494

8595
let wo = normalize(view.world_position - surface.world_position);
8696
let brdf = evaluate_brdf(surface.world_normal, wo, merge_result.wi, surface.material);
@@ -135,63 +145,67 @@ fn generate_initial_reservoir(world_position: vec3<f32>, world_normal: vec3<f32>
135145
return reservoir;
136146
}
137147

138-
fn load_temporal_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>) -> Reservoir {
148+
fn load_temporal_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>) -> NeighborInfo {
139149
let motion_vector = textureLoad(motion_vectors, pixel_id, 0).xy;
140150
let temporal_pixel_id_float = round(vec2<f32>(pixel_id) - (motion_vector * view.main_pass_viewport.zw));
141151

142152
// Check if the current pixel was off screen during the previous frame (current pixel is newly visible),
143153
// or if all temporal history should assumed to be invalid
144154
if any(temporal_pixel_id_float < vec2(0.0)) || any(temporal_pixel_id_float >= view.main_pass_viewport.zw) || bool(constants.reset) {
145-
return empty_reservoir();
155+
return NeighborInfo(empty_reservoir(), vec3(0.0), vec3(0.0), vec3(0.0));
146156
}
147157

148158
let permuted_temporal_pixel_id = permute_pixel(vec2<u32>(temporal_pixel_id_float), constants.frame_index, view.viewport.zw);
149-
var temporal_reservoir = load_temporal_reservoir_inner(permuted_temporal_pixel_id, depth, world_position, world_normal);
159+
var temporal = load_temporal_reservoir_inner(permuted_temporal_pixel_id, depth, world_position, world_normal);
150160

151161
// If permuted reprojection failed (tends to happen on object edges), try point reprojection
152-
if !reservoir_valid(temporal_reservoir) {
153-
temporal_reservoir = load_temporal_reservoir_inner(vec2<u32>(temporal_pixel_id_float), depth, world_position, world_normal);
162+
if !reservoir_valid(temporal.reservoir) {
163+
temporal = load_temporal_reservoir_inner(vec2<u32>(temporal_pixel_id_float), depth, world_position, world_normal);
154164
}
155165

156166
// Check if the light selected in the previous frame no longer exists in the current frame (e.g. entity despawned)
157-
let previous_light_id = temporal_reservoir.sample.light_id >> 16u;
158-
let triangle_id = temporal_reservoir.sample.light_id & 0xFFFFu;
167+
let previous_light_id = temporal.reservoir.sample.light_id >> 16u;
168+
let triangle_id = temporal.reservoir.sample.light_id & 0xFFFFu;
159169
let light_id = previous_frame_light_id_translations[previous_light_id];
160170
if light_id == LIGHT_NOT_PRESENT_THIS_FRAME {
161-
return empty_reservoir();
171+
return NeighborInfo(empty_reservoir(), vec3(0.0), vec3(0.0), vec3(0.0));
162172
}
163-
temporal_reservoir.sample.light_id = (light_id << 16u) | triangle_id;
173+
temporal.reservoir.sample.light_id = (light_id << 16u) | triangle_id;
164174

165-
temporal_reservoir.confidence_weight = min(temporal_reservoir.confidence_weight, CONFIDENCE_WEIGHT_CAP);
175+
temporal.reservoir.confidence_weight = min(temporal.reservoir.confidence_weight, CONFIDENCE_WEIGHT_CAP);
166176

167-
return temporal_reservoir;
177+
return temporal;
168178
}
169179

170-
fn load_temporal_reservoir_inner(temporal_pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>) -> Reservoir {
180+
fn load_temporal_reservoir_inner(temporal_pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>) -> NeighborInfo {
171181
// Check if the pixel features have changed heavily between the current and previous frame
172182
let temporal_depth = textureLoad(previous_depth_buffer, temporal_pixel_id, 0);
173183
let temporal_surface = gpixel_resolve(textureLoad(previous_gbuffer, temporal_pixel_id, 0), temporal_depth, temporal_pixel_id, view.main_pass_viewport.zw, previous_view.world_from_clip);
184+
let temporal_diffuse_brdf = temporal_surface.material.base_color / PI;
174185
if pixel_dissimilar(depth, world_position, temporal_surface.world_position, world_normal, temporal_surface.world_normal, view) {
175-
return empty_reservoir();
186+
return NeighborInfo(empty_reservoir(), vec3(0.0), vec3(0.0), vec3(0.0));
176187
}
177188

178-
return load_reservoir_a(temporal_pixel_id);
189+
let temporal_reservoir = load_reservoir_a(temporal_pixel_id);
190+
return NeighborInfo(temporal_reservoir, temporal_surface.world_position, temporal_surface.world_normal, temporal_diffuse_brdf);
179191
}
180192

181-
fn load_spatial_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> Reservoir {
193+
fn load_spatial_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> NeighborInfo {
182194
for (var i = 0u; i < 5u; i++) {
183195
let spatial_pixel_id = get_neighbor_pixel_id(pixel_id, rng);
184196

185197
let spatial_depth = textureLoad(depth_buffer, spatial_pixel_id, 0);
186198
let spatial_surface = gpixel_resolve(textureLoad(gbuffer, spatial_pixel_id, 0), spatial_depth, spatial_pixel_id, view.main_pass_viewport.zw, view.world_from_clip);
199+
let spatial_diffuse_brdf = spatial_surface.material.base_color / PI;
187200
if pixel_dissimilar(depth, world_position, spatial_surface.world_position, world_normal, spatial_surface.world_normal, view) {
188201
continue;
189202
}
190203

191-
return load_reservoir_b(spatial_pixel_id);
204+
let spatial_reservoir = load_reservoir_b(spatial_pixel_id);
205+
return NeighborInfo(spatial_reservoir, spatial_surface.world_position, spatial_surface.world_normal, spatial_diffuse_brdf);
192206
}
193207

194-
return empty_reservoir();
208+
return NeighborInfo(empty_reservoir(), world_position, world_normal, vec3(0.0));
195209
}
196210

197211
fn get_neighbor_pixel_id(center_pixel_id: vec2<u32>, rng: ptr<function, u32>) -> vec2<u32> {
@@ -200,6 +214,13 @@ fn get_neighbor_pixel_id(center_pixel_id: vec2<u32>, rng: ptr<function, u32>) ->
200214
return vec2<u32>(spatial_id);
201215
}
202216

217+
struct NeighborInfo {
218+
reservoir: Reservoir,
219+
world_position: vec3<f32>,
220+
world_normal: vec3<f32>,
221+
diffuse_brdf: vec3<f32>,
222+
}
223+
203224
struct Reservoir {
204225
sample: LightSample,
205226
confidence_weight: f32,
@@ -252,42 +273,56 @@ struct ReservoirMergeResult {
252273

253274
fn merge_reservoirs(
254275
canonical_reservoir: Reservoir,
276+
canonical_world_position: vec3<f32>,
277+
canonical_world_normal: vec3<f32>,
278+
canonical_diffuse_brdf: vec3<f32>,
255279
other_reservoir: Reservoir,
256-
world_position: vec3<f32>,
257-
world_normal: vec3<f32>,
258-
diffuse_brdf: vec3<f32>,
280+
other_world_position: vec3<f32>,
281+
other_world_normal: vec3<f32>,
282+
other_diffuse_brdf: vec3<f32>,
259283
rng: ptr<function, u32>,
260284
) -> ReservoirMergeResult {
261-
let canonical_contribution = reservoir_contribution(canonical_reservoir, world_position, world_normal, diffuse_brdf);
262-
let other_contribution = reservoir_contribution(other_reservoir, world_position, world_normal, diffuse_brdf);
263-
264-
let mis_weight_denominator = 1.0 / (canonical_reservoir.confidence_weight + other_reservoir.confidence_weight);
265-
266-
let canonical_mis_weight = canonical_reservoir.confidence_weight * mis_weight_denominator;
267-
let canonical_resampling_weight = canonical_mis_weight * (canonical_contribution.target_function * canonical_reservoir.unbiased_contribution_weight);
268-
269-
let other_mis_weight = other_reservoir.confidence_weight * mis_weight_denominator;
270-
let other_resampling_weight = other_mis_weight * (other_contribution.target_function * other_reservoir.unbiased_contribution_weight);
285+
// Contributions for resampling
286+
let canonical_contribution_canonical_sample = reservoir_contribution(canonical_reservoir, canonical_world_position, canonical_world_normal, canonical_diffuse_brdf);
287+
let canonical_contribution_other_sample = reservoir_contribution(other_reservoir, canonical_world_position, canonical_world_normal, canonical_diffuse_brdf);
288+
289+
// Extra contributions for MIS
290+
let other_contribution_canonical_sample = reservoir_contribution(canonical_reservoir, other_world_position, other_world_normal, other_diffuse_brdf);
291+
let other_contribution_other_sample = reservoir_contribution(other_reservoir, other_world_position, other_world_normal, other_diffuse_brdf);
292+
293+
// Resampling weight for canonical sample
294+
let canonical_sample_mis_weight = balance_heuristic(
295+
canonical_reservoir.confidence_weight * canonical_contribution_canonical_sample.target_function,
296+
other_reservoir.confidence_weight * other_contribution_canonical_sample.target_function,
297+
);
298+
let canonical_sample_resampling_weight = canonical_sample_mis_weight * canonical_contribution_canonical_sample.target_function * canonical_reservoir.unbiased_contribution_weight;
271299

272-
let weight_sum = canonical_resampling_weight + other_resampling_weight;
300+
// Resampling weight for other sample
301+
let other_sample_mis_weight = balance_heuristic(
302+
other_reservoir.confidence_weight * other_contribution_other_sample.target_function,
303+
canonical_reservoir.confidence_weight * canonical_contribution_other_sample.target_function,
304+
);
305+
let other_sample_resampling_weight = other_sample_mis_weight * canonical_contribution_other_sample.target_function * other_reservoir.unbiased_contribution_weight;
273306

307+
// Perform resampling
274308
var combined_reservoir = empty_reservoir();
275309
combined_reservoir.confidence_weight = canonical_reservoir.confidence_weight + other_reservoir.confidence_weight;
310+
let weight_sum = canonical_sample_resampling_weight + other_sample_resampling_weight;
276311

277-
if rand_f(rng) < other_resampling_weight / weight_sum {
312+
if rand_f(rng) < other_sample_resampling_weight / weight_sum {
278313
combined_reservoir.sample = other_reservoir.sample;
279314

280-
let inverse_target_function = select(0.0, 1.0 / other_contribution.target_function, other_contribution.target_function > 0.0);
315+
let inverse_target_function = select(0.0, 1.0 / canonical_contribution_other_sample.target_function, canonical_contribution_other_sample.target_function > 0.0);
281316
combined_reservoir.unbiased_contribution_weight = weight_sum * inverse_target_function;
282317

283-
return ReservoirMergeResult(combined_reservoir, other_contribution.radiance, other_contribution.wi);
318+
return ReservoirMergeResult(combined_reservoir, canonical_contribution_other_sample.radiance, canonical_contribution_other_sample.wi);
284319
} else {
285320
combined_reservoir.sample = canonical_reservoir.sample;
286321

287-
let inverse_target_function = select(0.0, 1.0 / canonical_contribution.target_function, canonical_contribution.target_function > 0.0);
322+
let inverse_target_function = select(0.0, 1.0 / canonical_contribution_canonical_sample.target_function, canonical_contribution_canonical_sample.target_function > 0.0);
288323
combined_reservoir.unbiased_contribution_weight = weight_sum * inverse_target_function;
289324

290-
return ReservoirMergeResult(combined_reservoir, canonical_contribution.radiance, canonical_contribution.wi);
325+
return ReservoirMergeResult(combined_reservoir, canonical_contribution_canonical_sample.radiance, canonical_contribution_canonical_sample.wi);
291326
}
292327
}
293328

crates/bevy_solari/src/realtime/restir_gi.wgsl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#import bevy_render::view::View
88
#import bevy_solari::brdf::evaluate_diffuse_brdf
99
#import bevy_solari::gbuffer_utils::{gpixel_resolve, pixel_dissimilar, permute_pixel}
10-
#import bevy_solari::sampling::{sample_random_light, trace_point_visibility}
10+
#import bevy_solari::sampling::{sample_random_light, trace_point_visibility, balance_heuristic}
1111
#import bevy_solari::scene_bindings::{trace_ray, resolve_ray_hit_full, RAY_T_MIN, RAY_T_MAX}
1212
#import bevy_solari::world_cache::{query_world_cache, WORLD_CACHE_CELL_LIFETIME}
1313

@@ -321,11 +321,3 @@ fn merge_reservoirs(
321321
return ReservoirMergeResult(combined_reservoir, canonical_sample_radiance);
322322
}
323323
}
324-
325-
fn balance_heuristic(x: f32, y: f32) -> f32 {
326-
let sum = x + y;
327-
if sum == 0.0 {
328-
return 0.0;
329-
}
330-
return x / sum;
331-
}

crates/bevy_solari/src/realtime/specular_gi.wgsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ fn nee_mis_weight(inverse_p_light: f32, brdf_rays_can_hit: bool, wo_tangent: vec
146146

147147
let p_light = 1.0 / inverse_p_light;
148148
let p_bounce = ggx_vndf_pdf(wo_tangent, wi_tangent, ray_hit.material.roughness);
149-
return max(0.0, power_heuristic(p_light, p_bounce));
149+
return power_heuristic(p_light, p_bounce);
150150
}
151151

152152
// Don't adjust the size of this struct without also adjusting GI_RESERVOIR_STRUCT_SIZE.

crates/bevy_solari/src/realtime/world_cache_query.wgsl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,19 @@ struct WorldCacheGeometryData {
4646
@group(1) @binding(23) var<storage, read_write> world_cache_active_cells_count: u32;
4747

4848
#ifndef WORLD_CACHE_NON_ATOMIC_LIFE_BUFFER
49-
fn query_world_cache(world_position: vec3<f32>, world_normal: vec3<f32>, view_position: vec3<f32>, cell_lifetime: u32, rng: ptr<function, u32>) -> vec3<f32> {
49+
fn query_world_cache(world_position_in: vec3<f32>, world_normal: vec3<f32>, view_position: vec3<f32>, cell_lifetime: u32, rng: ptr<function, u32>) -> vec3<f32> {
50+
var world_position = world_position_in;
5051
var cell_size = get_cell_size(world_position, view_position);
5152

5253
// https://tomclabault.github.io/blog/2025/regir, jitter_world_position_tangent_plane
54+
#ifndef NO_JITTER_WORLD_CACHE
5355
let TBN = orthonormalize(world_normal);
5456
let offset = (rand_vec2f(rng) * 2.0 - 1.0) * cell_size * 0.5;
55-
let jittered_position = world_position + offset.x * TBN[0] + offset.y * TBN[1];
56-
cell_size = get_cell_size(jittered_position, view_position);
57+
world_position += offset.x * TBN[0] + offset.y * TBN[1];
58+
cell_size = get_cell_size(world_position, view_position);
59+
#endif
5760

58-
let world_position_quantized = bitcast<vec3<u32>>(quantize_position(jittered_position, cell_size));
61+
let world_position_quantized = bitcast<vec3<u32>>(quantize_position(world_position, cell_size));
5962
let world_normal_quantized = bitcast<vec3<u32>>(quantize_normal(world_normal));
6063
var key = compute_key(world_position_quantized, world_normal_quantized);
6164
let checksum = compute_checksum(world_position_quantized, world_normal_quantized);
@@ -77,7 +80,7 @@ fn query_world_cache(world_position: vec3<f32>, world_normal: vec3<f32>, view_po
7780
return world_cache_radiance[key].rgb;
7881
} else if existing_checksum == WORLD_CACHE_EMPTY_CELL {
7982
// Cell is empty - initialize it
80-
world_cache_geometry_data[key].world_position = jittered_position;
83+
world_cache_geometry_data[key].world_position = world_position;
8184
world_cache_geometry_data[key].world_normal = world_normal;
8285
return vec3(0.0);
8386
} else {

crates/bevy_solari/src/scene/sampling.wgsl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@
66
#import bevy_solari::scene_bindings::{trace_ray, RAY_T_MIN, RAY_T_MAX, light_sources, directional_lights, LightSource, LIGHT_SOURCE_KIND_DIRECTIONAL, resolve_triangle_data_full, ResolvedRayHitFull}
77

88
fn power_heuristic(f: f32, g: f32) -> f32 {
9-
return f * f / (f * f + g * g);
9+
return balance_heuristic(f * f, g * g);
1010
}
1111

1212
fn balance_heuristic(f: f32, g: f32) -> f32 {
13-
return f / (f + g);
13+
let sum = f + g;
14+
if sum == 0.0 {
15+
return 0.0;
16+
}
17+
return max(0.0, f / sum);
1418
}
1519

1620
// https://gpuopen.com/download/Bounded_VNDF_Sampling_for_Smith-GGX_Reflections.pdf (Listing 1)

0 commit comments

Comments
 (0)