use std::mem::size_of;
use std::num::NonZeroU64;
use thiserror::Error;

use crate::{
    device::DeviceError,
    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
};

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum CreateDispatchIndirectValidationPipelineError {
    #[error(transparent)]
    DeviceError(#[from] DeviceError),
    #[error(transparent)]
    ShaderModule(#[from] CreateShaderModuleError),
    #[error(transparent)]
    ComputePipeline(#[from] CreateComputePipelineError),
}

/// This machinery requires the following limits:
///
/// - max_bind_groups: 2,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 2,
/// - max_storage_buffer_binding_size: 3 * min_storage_buffer_offset_alignment,
/// - max_push_constant_size: 4,
/// - max_compute_invocations_per_workgroup: 1
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub struct IndirectValidation {
    module: Box<dyn hal::DynShaderModule>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,
    dst_buffer: Box<dyn hal::DynBuffer>,
    dst_bind_group: Box<dyn hal::DynBindGroup>,
}

pub struct Params<'a> {
    pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
    pub pipeline: &'a dyn hal::DynComputePipeline,
    pub dst_buffer: &'a dyn hal::DynBuffer,
    pub dst_bind_group: &'a dyn hal::DynBindGroup,
    pub aligned_offset: u64,
    pub offset_remainder: u64,
}

impl IndirectValidation {
    pub fn new(
        device: &dyn hal::DynDevice,
        limits: &wgt::Limits,
    ) -> Result<Self, CreateDispatchIndirectValidationPipelineError> {
        let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension;
        let src = format!(
            "
            @group(0) @binding(0)
            var<storage, read_write> dst: array<u32, 6>;
            @group(1) @binding(0)
            var<storage, read> src: array<u32>;

            struct OffsetPc {{
                inner: u32,
            }}
            var<push_constant> offset: OffsetPc;

            @compute @workgroup_size(1)
            fn main() {{
                let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]);
                let max_compute_workgroups_per_dimension = {max_compute_workgroups_per_dimension}u;
                if (
                    src.x > max_compute_workgroups_per_dimension ||
                    src.y > max_compute_workgroups_per_dimension ||
                    src.z > max_compute_workgroups_per_dimension
                ) {{
                    dst = array(0u, 0u, 0u, 0u, 0u, 0u);
                }} else {{
                    dst = array(src.x, src.y, src.z, src.x, src.y, src.z);
                }}
            }}
        "
        );
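        // Illustrative note (an example, not from the shader above): with, say,
        // `max_compute_workgroups_per_dimension = 65535`, the shader maps
        // `src = (2, 3, 4)` to `dst = (2, 3, 4, 2, 3, 4)`, while an
        // out-of-range `src = (70000, 1, 1)` becomes `dst = (0, 0, 0, 0, 0, 0)`,
        // so the indirect dispatch that later reads `dst` dispatches zero
        // workgroups instead of exceeding the device limit.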
        // SAFETY: The value we are passing to `new_unchecked` is not zero, so this is safe.
        const SRC_BUFFER_SIZE: NonZeroU64 =
            unsafe { NonZeroU64::new_unchecked(size_of::<u32>() as u64 * 3) };
        // SAFETY: The value we are passing to `new_unchecked` is not zero, so this is safe.
        const DST_BUFFER_SIZE: NonZeroU64 = unsafe {
            NonZeroU64::new_unchecked(
                SRC_BUFFER_SIZE.get() * 2, // From above: `dst: array<u32, 6>`
            )
        };

        let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
            CreateShaderModuleError::Parsing(naga::error::ShaderError {
                source: src.clone(),
                label: None,
                inner: Box::new(inner),
            })
        })?;
        let info = crate::device::create_validator(
            wgt::Features::PUSH_CONSTANTS,
            wgt::DownlevelFlags::empty(),
            naga::valid::ValidationFlags::all(),
        )
        .validate(&module)
        .map_err(|inner| {
            CreateShaderModuleError::Validation(naga::error::ShaderError {
                source: src,
                label: None,
                inner: Box::new(inner),
            })
        })?;
        let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
            module: std::borrow::Cow::Owned(module),
            info,
            debug_source: None,
        });
        let hal_desc = hal::ShaderModuleDescriptor {
            label: None,
            runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
        };
        let module =
            unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| {
                match error {
                    hal::ShaderError::Device(error) => {
                        CreateShaderModuleError::Device(DeviceError::from_hal(error))
                    }
                    hal::ShaderError::Compilation(ref msg) => {
                        log::error!("Shader error: {}", msg);
                        CreateShaderModuleError::Generation
                    }
                }
            })?;

        let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
            label: None,
            flags: hal::BindGroupLayoutFlags::empty(),
            entries: &[wgt::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgt::ShaderStages::COMPUTE,
                ty: wgt::BindingType::Buffer {
                    ty: wgt::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: Some(DST_BUFFER_SIZE),
                },
                count: None,
            }],
        };
        let dst_bind_group_layout = unsafe {
            device
                .create_bind_group_layout(&dst_bind_group_layout_desc)
                .map_err(DeviceError::from_hal)?
        };
        let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
            label: None,
            flags: hal::BindGroupLayoutFlags::empty(),
            entries: &[wgt::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgt::ShaderStages::COMPUTE,
                ty: wgt::BindingType::Buffer {
                    ty: wgt::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: true,
                    min_binding_size: Some(SRC_BUFFER_SIZE),
                },
                count: None,
            }],
        };
        let src_bind_group_layout = unsafe {
            device
                .create_bind_group_layout(&src_bind_group_layout_desc)
                .map_err(DeviceError::from_hal)?
        };
        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                dst_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
            ],
            push_constant_ranges: &[wgt::PushConstantRange {
                stages: wgt::ShaderStages::COMPUTE,
                range: 0..4,
            }],
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let pipeline_desc = hal::ComputePipelineDescriptor {
            label: None,
            layout: pipeline_layout.as_ref(),
            stage: hal::ProgrammableStage {
                module: module.as_ref(),
                entry_point: "main",
                constants: &Default::default(),
                zero_initialize_workgroup_memory: false,
            },
            cache: None,
        };
        let pipeline =
            unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
                hal::PipelineError::Device(error) => {
                    CreateComputePipelineError::Device(DeviceError::from_hal(error))
                }
                hal::PipelineError::Linkage(_stages, msg) => {
                    CreateComputePipelineError::Internal(msg)
                }
                hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                    crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
                ),
                hal::PipelineError::PipelineConstants(_, error) => {
                    CreateComputePipelineError::PipelineConstants(error)
                }
            })?;

        let dst_buffer_desc = hal::BufferDescriptor {
            label: None,
            size: DST_BUFFER_SIZE.get(),
            usage: hal::BufferUses::INDIRECT | hal::BufferUses::STORAGE_READ_WRITE,
            memory_flags: hal::MemoryFlags::empty(),
        };
        let dst_buffer =
            unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;

        let dst_bind_group_desc = hal::BindGroupDescriptor {
            label: None,
            layout: dst_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            buffers: &[hal::BufferBinding {
                buffer: dst_buffer.as_ref(),
                offset: 0,
                size: Some(DST_BUFFER_SIZE),
            }],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
        };
        let dst_bind_group = unsafe {
            device
                .create_bind_group(&dst_bind_group_desc)
                .map_err(DeviceError::from_hal)
        }?;

        Ok(Self {
            module,
            dst_bind_group_layout,
            src_bind_group_layout,
            pipeline_layout,
            pipeline,
            dst_buffer,
            dst_bind_group,
        })
    }

    /// `Ok(None)` will only be returned if `buffer_size` is `0`.
    pub fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &wgt::Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            buffers: &[hal::BufferBinding {
                buffer,
                offset: 0,
                size: Some(binding_size),
            }],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

    pub fn params<'a>(&'a self, limits: &wgt::Limits, offset: u64, buffer_size: u64) -> Params<'a> {
        // The offset we receive is only required to be aligned to 4 bytes.
        //
        // Binding offsets and dynamic offsets are required to be aligned to
        // min_storage_buffer_offset_alignment (256 bytes by default).
        //
        // So, we work around this limitation by calculating an aligned offset
        // and passing the remainder through a push constant.
        //
        // We could bind the whole buffer and only have to pass the offset
        // through a push constant, but we might run into the
        // max_storage_buffer_binding_size limit.
        //
        // See the inner docs of `calculate_src_buffer_binding_size` for how
        // we get the appropriate `binding_size`.
        let alignment = limits.min_storage_buffer_offset_alignment as u64;
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

        let aligned_offset = offset - offset % alignment;
        // This works because `binding_size` is either `buffer_size` or
        // `alignment * 2 + buffer_size % alignment`.
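        //
        // Illustrative numbers (assuming the default 256-byte alignment): with
        // `buffer_size = 1000`, `binding_size = 2 * 256 + 1000 % 256 = 744` and
        // `max_aligned_offset = 1000 - 744 = 256`. For `offset = 988`, the naive
        // `aligned_offset` of `988 - 988 % 256 = 768` is clamped to `256`,
        // leaving `offset_remainder = 732`; the 12 bytes at `988..1000` still
        // fall inside the bound range `256..1000`.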
        let max_aligned_offset = buffer_size - binding_size;
        let aligned_offset = aligned_offset.min(max_aligned_offset);
        let offset_remainder = offset - aligned_offset;

        Params {
            pipeline_layout: self.pipeline_layout.as_ref(),
            pipeline: self.pipeline.as_ref(),
            dst_buffer: self.dst_buffer.as_ref(),
            dst_bind_group: self.dst_bind_group.as_ref(),
            aligned_offset,
            offset_remainder,
        }
    }

    pub fn dispose(self, device: &dyn hal::DynDevice) {
        let IndirectValidation {
            module,
            dst_bind_group_layout,
            src_bind_group_layout,
            pipeline_layout,
            pipeline,
            dst_buffer,
            dst_bind_group,
        } = self;

        unsafe {
            device.destroy_bind_group(dst_bind_group);
            device.destroy_buffer(dst_buffer);
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}

fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 {
    let alignment = limits.min_storage_buffer_offset_alignment as u64;

    // We need to choose a binding size that can address all possible sets of 12 contiguous bytes
    // in the buffer, taking into account that the dynamic offset needs to be a multiple of
    // `min_storage_buffer_offset_alignment`.

    // Given the known variables: `offset`, `buffer_size`, `alignment` and the rule
    // `offset + 12 <= buffer_size`:

    // Let `chunks = floor(buffer_size / alignment)`.
    // Let `chunk` be the interval `[0, chunks]`.
    // Let `offset = alignment * chunk + r` where `r` is the interval `[0, alignment - 4]`.
    // Let `binding` be the interval `[offset, offset + 12]`.
    // Let `aligned_offset = alignment * chunk`.
    // Let `aligned_binding` be the interval `[aligned_offset, aligned_offset + r + 12]`.
    // Let `aligned_binding_size = r + 12 = [12, alignment + 8]`.
    // Let `min_aligned_binding_size = alignment + 8`.

    // `min_aligned_binding_size` is the minimum binding size required to address all 12
    // contiguous bytes in the buffer, but the last `aligned_offset + min_aligned_binding_size`
    // might overflow the buffer. In order to avoid this, we must pick a larger `binding_size`
    // that satisfies: `last_aligned_offset + binding_size = buffer_size` and
    // `binding_size >= min_aligned_binding_size`.

    // Let `buffer_size = alignment * chunks + sr` where `sr` is the interval `[0, alignment - 4]`.
    // Let `last_aligned_offset = alignment * (chunks - u)` where `u` is the interval `[0, chunks]`.
    // => `binding_size = buffer_size - last_aligned_offset`
    // => `binding_size = alignment * chunks + sr - alignment * (chunks - u)`
    // => `binding_size = alignment * chunks + sr - alignment * chunks + alignment * u`
    // => `binding_size = sr + alignment * u`
    // => `min_aligned_binding_size <= sr + alignment * u`
    // => `alignment + 8 <= sr + alignment * u`
    // => `u` must be at least 2
    // => `binding_size = sr + alignment * 2`

    let binding_size = 2 * alignment + (buffer_size % alignment);
    binding_size.min(buffer_size)
}
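
// A minimal property-check sketch of the binding-size math above, assuming
// `wgt::Limits::default()` (256-byte storage offset alignment). It re-derives
// the `aligned_offset`/`offset_remainder` computation from
// `IndirectValidation::params` inline, since calling `params` itself would
// require a live device.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn src_binding_covers_all_valid_offsets() {
        let limits = wgt::Limits::default();
        let alignment = limits.min_storage_buffer_offset_alignment as u64;

        // Buffer sizes that can hold at least one set of dispatch args.
        for buffer_size in (12..2048u64).step_by(4) {
            let binding_size = calculate_src_buffer_binding_size(buffer_size, &limits);
            assert!(binding_size <= buffer_size);

            // Every 4-byte-aligned offset with `offset + 12 <= buffer_size`.
            for offset in (0..=buffer_size - 12).step_by(4) {
                // Mirrors the computation in `IndirectValidation::params`.
                let aligned_offset = offset - offset % alignment;
                let max_aligned_offset = buffer_size - binding_size;
                let aligned_offset = aligned_offset.min(max_aligned_offset);
                let offset_remainder = offset - aligned_offset;

                // The dynamic offset stays aligned and the binding stays
                // within the buffer...
                assert_eq!(aligned_offset % alignment, 0);
                assert!(aligned_offset + binding_size <= buffer_size);
                // ...and the 12 bytes of dispatch args fall inside the binding.
                assert!(offset_remainder + 12 <= binding_size);
            }
        }
    }
}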