use indexmap::IndexMap; use naga::{ Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument, FunctionResult, GatherMode, GlobalVariable, Handle, ImageQuery, LocalVariable, Module, Override, SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena, }; use std::{cell::RefCell, rc::Rc}; #[derive(Debug, Default)] pub struct DerivedModule<'a> { shader: Option<&'a Module>, span_offset: usize, /// Maps the original type handle to the the mangled type handle. type_map: IndexMap, Handle>, /// Maps the original const handle to the the mangled const handle. const_map: IndexMap, Handle>, /// Maps the original pipeline override handle to the the mangled pipeline override handle. pipeline_override_map: IndexMap, Handle>, /// Contains both const expressions and pipeline override constant expressions. /// The expressions are stored together because that's what Naga expects. global_expressions: Rc>>, /// Maps the original expression handle to the new expression handle for const expressions and pipeline override expressions. /// The expressions are stored together because that's what Naga expects. global_expression_map: Rc, Handle>>>, global_map: IndexMap, Handle>, function_map: IndexMap>, types: UniqueArena, constants: Arena, globals: Arena, functions: Arena, pipeline_overrides: Arena, special_types: naga::SpecialTypes, } impl<'a> DerivedModule<'a> { // set source context for import operations pub fn set_shader_source(&mut self, shader: &'a Module, span_offset: usize) { self.clear_shader_source(); self.shader = Some(shader); self.span_offset = span_offset; // eagerly import special types if let Some(h_special_type) = shader.special_types.ray_desc.as_ref() { if let Some(derived_special_type) = self.special_types.ray_desc.as_ref() { self.type_map.insert(*h_special_type, *derived_special_type); } else { self.special_types.ray_desc = Some(self.import_type(h_special_type)); } } if let Some(h_special_type) = shader.special_types.ray_intersection.as_ref() { if let Some(derived_special_type) = self.special_types.ray_intersection.as_ref() { self.type_map.insert(*h_special_type, *derived_special_type); } else { self.special_types.ray_intersection = Some(self.import_type(h_special_type)); } } for (predeclared, h_predeclared_type) in shader.special_types.predeclared_types.iter() { if let Some(derived_special_type) = self.special_types.predeclared_types.get(predeclared) { self.type_map .insert(*h_predeclared_type, *derived_special_type); } else { let new_h = self.import_type(h_predeclared_type); self.special_types .predeclared_types .insert(predeclared.clone(), new_h); } } } // detach source context pub fn clear_shader_source(&mut self) { self.shader = None; self.type_map.clear(); self.const_map.clear(); self.global_map.clear(); self.global_expression_map.borrow_mut().clear(); self.pipeline_override_map.clear(); } pub fn map_span(&self, span: Span) -> Span { let span = span.to_range(); match span { Some(rng) => Span::new( (rng.start + self.span_offset) as u32, (rng.end + self.span_offset) as u32, ), None => Span::UNDEFINED, } } // remap a type from source context into our derived context pub fn import_type(&mut self, h_type: &Handle) -> Handle { self.rename_type(h_type, None) } // remap a type from source context into our derived context, and rename it pub fn rename_type(&mut self, h_type: &Handle, name: Option) -> Handle { self.type_map.get(h_type).copied().unwrap_or_else(|| { let ty = self .shader .as_ref() .unwrap() .types .get_handle(*h_type) .unwrap(); let name = match name { Some(name) => Some(name), None => ty.name.clone(), }; let new_type = Type { name, inner: match &ty.inner { TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } | TypeInner::ValuePointer { .. } | TypeInner::Image { .. } | TypeInner::Sampler { .. } | TypeInner::Atomic { .. } | TypeInner::AccelerationStructure | TypeInner::RayQuery => ty.inner.clone(), TypeInner::Pointer { base, space } => TypeInner::Pointer { base: self.import_type(base), space: *space, }, TypeInner::Struct { members, span } => { let members = members .iter() .map(|m| StructMember { name: m.name.clone(), ty: self.import_type(&m.ty), binding: m.binding.clone(), offset: m.offset, }) .collect(); TypeInner::Struct { members, span: *span, } } TypeInner::Array { base, size, stride } => TypeInner::Array { base: self.import_type(base), size: *size, stride: *stride, }, TypeInner::BindingArray { base, size } => TypeInner::BindingArray { base: self.import_type(base), size: *size, }, }, }; let span = self.shader.as_ref().unwrap().types.get_span(*h_type); let new_h = self.types.insert(new_type, self.map_span(span)); self.type_map.insert(*h_type, new_h); new_h }) } // remap a const from source context into our derived context pub fn import_const(&mut self, h_const: &Handle) -> Handle { self.const_map.get(h_const).copied().unwrap_or_else(|| { let c = self .shader .as_ref() .unwrap() .constants .try_get(*h_const) .unwrap(); let new_const = Constant { name: c.name.clone(), ty: self.import_type(&c.ty), init: self.import_global_expression(c.init), }; let span = self.shader.as_ref().unwrap().constants.get_span(*h_const); let new_h = self .constants .fetch_or_append(new_const, self.map_span(span)); self.const_map.insert(*h_const, new_h); new_h }) } // remap a global from source context into our derived context pub fn import_global(&mut self, h_global: &Handle) -> Handle { self.global_map.get(h_global).copied().unwrap_or_else(|| { let gv = self .shader .as_ref() .unwrap() .global_variables .try_get(*h_global) .unwrap(); let new_global = GlobalVariable { name: gv.name.clone(), space: gv.space, binding: gv.binding.clone(), ty: self.import_type(&gv.ty), init: gv.init.map(|c| self.import_global_expression(c)), }; let span = self .shader .as_ref() .unwrap() .global_variables .get_span(*h_global); let new_h = self .globals .fetch_or_append(new_global, self.map_span(span)); self.global_map.insert(*h_global, new_h); new_h }) } // remap either a const or pipeline override expression from source context into our derived context pub fn import_global_expression(&mut self, h_expr: Handle) -> Handle { self.import_expression( h_expr, &self.shader.as_ref().unwrap().global_expressions, self.global_expression_map.clone(), self.global_expressions.clone(), false, true, ) } // remap a pipeline override from source context into our derived context pub fn import_pipeline_override(&mut self, h_override: &Handle) -> Handle { self.pipeline_override_map .get(h_override) .copied() .unwrap_or_else(|| { let pipeline_override = self .shader .as_ref() .unwrap() .overrides .try_get(*h_override) .unwrap(); let new_override = Override { name: pipeline_override.name.clone(), id: pipeline_override.id, ty: self.import_type(&pipeline_override.ty), init: pipeline_override .init .map(|init| self.import_global_expression(init)), }; let span = self .shader .as_ref() .unwrap() .overrides .get_span(*h_override); let new_h = self .pipeline_overrides .fetch_or_append(new_override, self.map_span(span)); self.pipeline_override_map.insert(*h_override, new_h); new_h }) } // remap a block fn import_block( &mut self, block: &Block, old_expressions: &Arena, already_imported: Rc, Handle>>>, new_expressions: Rc>>, ) -> Block { macro_rules! map_expr { ($e:expr) => { self.import_expression( *$e, old_expressions, already_imported.clone(), new_expressions.clone(), false, false, ) }; } macro_rules! map_expr_opt { ($e:expr) => { $e.as_ref().map(|expr| map_expr!(expr)) }; } macro_rules! map_block { ($b:expr) => { self.import_block( $b, old_expressions, already_imported.clone(), new_expressions.clone(), ) }; } let statements = block .iter() .map(|stmt| { match stmt { // remap function calls Statement::Call { function, arguments, result, } => Statement::Call { function: self.map_function_handle(function), arguments: arguments.iter().map(|expr| map_expr!(expr)).collect(), result: result.as_ref().map(|result| map_expr!(result)), }, // recursively Statement::Block(b) => Statement::Block(map_block!(b)), Statement::If { condition, accept, reject, } => Statement::If { condition: map_expr!(condition), accept: map_block!(accept), reject: map_block!(reject), }, Statement::Switch { selector, cases } => Statement::Switch { selector: map_expr!(selector), cases: cases .iter() .map(|case| SwitchCase { value: case.value, body: map_block!(&case.body), fall_through: case.fall_through, }) .collect(), }, Statement::Loop { body, continuing, break_if, } => Statement::Loop { body: map_block!(body), continuing: map_block!(continuing), break_if: map_expr_opt!(break_if), }, // map expressions Statement::Emit(exprs) => { // iterate once to add expressions that should NOT be part of the emit statement for expr in exprs.clone() { self.import_expression( expr, old_expressions, already_imported.clone(), new_expressions.clone(), true, false, ); } let old_length = new_expressions.borrow().len(); // iterate again to add expressions that should be part of the emit statement for expr in exprs.clone() { map_expr!(&expr); } Statement::Emit(new_expressions.borrow().range_from(old_length)) } Statement::Store { pointer, value } => Statement::Store { pointer: map_expr!(pointer), value: map_expr!(value), }, Statement::ImageStore { image, coordinate, array_index, value, } => Statement::ImageStore { image: map_expr!(image), coordinate: map_expr!(coordinate), array_index: map_expr_opt!(array_index), value: map_expr!(value), }, Statement::Atomic { pointer, fun, value, result, } => { let fun = match fun { AtomicFunction::Exchange { compare: Some(compare_expr), } => AtomicFunction::Exchange { compare: Some(map_expr!(compare_expr)), }, fun => *fun, }; Statement::Atomic { pointer: map_expr!(pointer), fun, value: map_expr!(value), result: map_expr_opt!(result), } } Statement::WorkGroupUniformLoad { pointer, result } => { Statement::WorkGroupUniformLoad { pointer: map_expr!(pointer), result: map_expr!(result), } } Statement::Return { value } => Statement::Return { value: map_expr_opt!(value), }, Statement::RayQuery { query, fun } => Statement::RayQuery { query: map_expr!(query), fun: match fun { naga::RayQueryFunction::Initialize { acceleration_structure, descriptor, } => naga::RayQueryFunction::Initialize { acceleration_structure: map_expr!(acceleration_structure), descriptor: map_expr!(descriptor), }, naga::RayQueryFunction::Proceed { result } => { naga::RayQueryFunction::Proceed { result: map_expr!(result), } } naga::RayQueryFunction::Terminate => naga::RayQueryFunction::Terminate, }, }, Statement::SubgroupBallot { result, predicate } => Statement::SubgroupBallot { result: map_expr!(result), predicate: map_expr_opt!(predicate), }, Statement::SubgroupGather { mut mode, argument, result, } => { match mode { GatherMode::BroadcastFirst => (), GatherMode::Broadcast(ref mut h_src) | GatherMode::Shuffle(ref mut h_src) | GatherMode::ShuffleDown(ref mut h_src) | GatherMode::ShuffleUp(ref mut h_src) | GatherMode::ShuffleXor(ref mut h_src) => *h_src = map_expr!(h_src), }; Statement::SubgroupGather { mode, argument: map_expr!(argument), result: map_expr!(result), } } Statement::SubgroupCollectiveOperation { op, collective_op, argument, result, } => Statement::SubgroupCollectiveOperation { op: *op, collective_op: *collective_op, argument: map_expr!(argument), result: map_expr!(result), }, Statement::ImageAtomic { image, coordinate, array_index, fun, value, } => { let fun = match fun { AtomicFunction::Exchange { compare: Some(compare_expr), } => AtomicFunction::Exchange { compare: Some(map_expr!(compare_expr)), }, fun => *fun, }; Statement::ImageAtomic { image: map_expr!(image), coordinate: map_expr!(coordinate), array_index: map_expr_opt!(array_index), fun, value: map_expr!(value), } } // else just copy Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => stmt.clone(), } }) .collect(); let mut new_block = Block::from_vec(statements); for ((_, new_span), (_, old_span)) in new_block.span_iter_mut().zip(block.span_iter()) { *new_span.unwrap() = self.map_span(*old_span); } new_block } fn import_expression( &mut self, h_expr: Handle, old_expressions: &Arena, already_imported: Rc, Handle>>>, new_expressions: Rc>>, non_emitting_only: bool, // only brings items that should NOT be emitted into scope unique: bool, // ensure expressions are unique with custom comparison ) -> Handle { if let Some(h_new) = already_imported.borrow().get(&h_expr) { return *h_new; } macro_rules! map_expr { ($e:expr) => { self.import_expression( *$e, old_expressions, already_imported.clone(), new_expressions.clone(), non_emitting_only, unique, ) }; } macro_rules! map_expr_opt { ($e:expr) => { $e.as_ref().map(|expr| { self.import_expression( *expr, old_expressions, already_imported.clone(), new_expressions.clone(), non_emitting_only, unique, ) }) }; } let mut is_external = false; let expr = old_expressions.try_get(h_expr).unwrap(); let expr = match expr { Expression::Literal(_) => { is_external = true; expr.clone() } Expression::ZeroValue(zv) => { is_external = true; Expression::ZeroValue(self.import_type(zv)) } Expression::CallResult(f) => Expression::CallResult(self.map_function_handle(f)), Expression::Constant(c) => { is_external = true; Expression::Constant(self.import_const(c)) } Expression::Compose { ty, components } => Expression::Compose { ty: self.import_type(ty), components: components.iter().map(|expr| map_expr!(expr)).collect(), }, Expression::GlobalVariable(gv) => { is_external = true; Expression::GlobalVariable(self.import_global(gv)) } Expression::ImageSample { image, sampler, gather, coordinate, array_index, offset, level, depth_ref, } => Expression::ImageSample { image: map_expr!(image), sampler: map_expr!(sampler), gather: *gather, coordinate: map_expr!(coordinate), array_index: map_expr_opt!(array_index), offset: offset.map(|c| self.import_global_expression(c)), level: match level { SampleLevel::Auto | SampleLevel::Zero => *level, SampleLevel::Exact(expr) => SampleLevel::Exact(map_expr!(expr)), SampleLevel::Bias(expr) => SampleLevel::Bias(map_expr!(expr)), SampleLevel::Gradient { x, y } => SampleLevel::Gradient { x: map_expr!(x), y: map_expr!(y), }, }, depth_ref: map_expr_opt!(depth_ref), }, Expression::Access { base, index } => Expression::Access { base: map_expr!(base), index: map_expr!(index), }, Expression::AccessIndex { base, index } => Expression::AccessIndex { base: map_expr!(base), index: *index, }, Expression::Splat { size, value } => Expression::Splat { size: *size, value: map_expr!(value), }, Expression::Swizzle { size, vector, pattern, } => Expression::Swizzle { size: *size, vector: map_expr!(vector), pattern: *pattern, }, Expression::Load { pointer } => Expression::Load { pointer: map_expr!(pointer), }, Expression::ImageLoad { image, coordinate, array_index, sample, level, } => Expression::ImageLoad { image: map_expr!(image), coordinate: map_expr!(coordinate), array_index: map_expr_opt!(array_index), sample: map_expr_opt!(sample), level: map_expr_opt!(level), }, Expression::ImageQuery { image, query } => Expression::ImageQuery { image: map_expr!(image), query: match query { ImageQuery::Size { level } => ImageQuery::Size { level: map_expr_opt!(level), }, _ => *query, }, }, Expression::Unary { op, expr } => Expression::Unary { op: *op, expr: map_expr!(expr), }, Expression::Binary { op, left, right } => Expression::Binary { op: *op, left: map_expr!(left), right: map_expr!(right), }, Expression::Select { condition, accept, reject, } => Expression::Select { condition: map_expr!(condition), accept: map_expr!(accept), reject: map_expr!(reject), }, Expression::Derivative { axis, expr, ctrl } => Expression::Derivative { axis: *axis, expr: map_expr!(expr), ctrl: *ctrl, }, Expression::Relational { fun, argument } => Expression::Relational { fun: *fun, argument: map_expr!(argument), }, Expression::Math { fun, arg, arg1, arg2, arg3, } => Expression::Math { fun: *fun, arg: map_expr!(arg), arg1: map_expr_opt!(arg1), arg2: map_expr_opt!(arg2), arg3: map_expr_opt!(arg3), }, Expression::As { expr, kind, convert, } => Expression::As { expr: map_expr!(expr), kind: *kind, convert: *convert, }, Expression::ArrayLength(expr) => Expression::ArrayLength(map_expr!(expr)), Expression::LocalVariable(_) | Expression::FunctionArgument(_) => { is_external = true; expr.clone() } Expression::AtomicResult { ty, comparison } => Expression::AtomicResult { ty: self.import_type(ty), comparison: *comparison, }, Expression::WorkGroupUniformLoadResult { ty } => { Expression::WorkGroupUniformLoadResult { ty: self.import_type(ty), } } Expression::RayQueryProceedResult => expr.clone(), Expression::RayQueryGetIntersection { query, committed } => { Expression::RayQueryGetIntersection { query: map_expr!(query), committed: *committed, } } Expression::Override(h_override) => { is_external = true; Expression::Override(self.import_pipeline_override(h_override)) } Expression::SubgroupBallotResult => expr.clone(), Expression::SubgroupOperationResult { ty } => Expression::SubgroupOperationResult { ty: self.import_type(ty), }, }; if !non_emitting_only || is_external { let span = old_expressions.get_span(h_expr); let h_new = if unique { new_expressions.borrow_mut().fetch_if_or_append( expr, self.map_span(span), |lhs, rhs| lhs == rhs, ) } else { new_expressions .borrow_mut() .append(expr, self.map_span(span)) }; already_imported.borrow_mut().insert(h_expr, h_new); h_new } else { h_expr } } // remap function global references (global vars, consts, types) into our derived context pub fn localize_function(&mut self, func: &Function) -> Function { let arguments = func .arguments .iter() .map(|arg| FunctionArgument { name: arg.name.clone(), ty: self.import_type(&arg.ty), binding: arg.binding.clone(), }) .collect(); let result = func.result.as_ref().map(|r| FunctionResult { ty: self.import_type(&r.ty), binding: r.binding.clone(), }); let expressions = Rc::new(RefCell::new(Arena::new())); let expr_map = Rc::new(RefCell::new(IndexMap::new())); let mut local_variables = Arena::new(); for (h_l, l) in func.local_variables.iter() { let new_local = LocalVariable { name: l.name.clone(), ty: self.import_type(&l.ty), init: l.init.map(|c| { self.import_expression( c, &func.expressions, expr_map.clone(), expressions.clone(), false, true, ) }), }; let span = func.local_variables.get_span(h_l); let new_h = local_variables.append(new_local, self.map_span(span)); assert_eq!(h_l, new_h); } let body = self.import_block( &func.body, &func.expressions, expr_map.clone(), expressions.clone(), ); let named_expressions = func .named_expressions .iter() .flat_map(|(h_expr, name)| { expr_map .borrow() .get(h_expr) .map(|new_h| (*new_h, name.clone())) }) .collect::>>(); Function { name: func.name.clone(), arguments, result, local_variables, expressions: Rc::try_unwrap(expressions).unwrap().into_inner(), named_expressions, body, diagnostic_filter_leaf: None, } } // import a function defined in the source shader context. // func name may be already defined, the returned handle will refer to the new function. // the previously defined function will still be valid. pub fn import_function(&mut self, func: &Function, span: Span) -> Handle { let name = func.name.as_ref().unwrap().clone(); let mapped_func = self.localize_function(func); let new_span = self.map_span(span); let new_h = self.functions.append(mapped_func, new_span); self.function_map.insert(name, new_h); new_h } // get the derived handle corresponding to the given source function handle // requires func to be named pub fn map_function_handle(&mut self, h_func: &Handle) -> Handle { let functions = &self.shader.as_ref().unwrap().functions; let func = functions.try_get(*h_func).unwrap(); let name = func.name.as_ref().unwrap(); self.function_map.get(name).copied().unwrap_or_else(|| { let span = functions.get_span(*h_func); self.import_function(func, span) }) } /// swap an already imported function for a new one. /// note span cannot be updated pub fn import_function_if_new(&mut self, func: &Function, span: Span) -> Handle { let name = func.name.as_ref().unwrap().clone(); if let Some(h) = self.function_map.get(&name) { return *h; } self.import_function(func, span) } /// get any required special types for this module pub fn has_required_special_types(&self) -> bool { !self.special_types.predeclared_types.is_empty() || self.special_types.ray_desc.is_some() || self.special_types.ray_intersection.is_some() } pub fn into_module_with_entrypoints(mut self) -> naga::Module { let entry_points = self .shader .unwrap() .entry_points .iter() .map(|ep| EntryPoint { name: ep.name.clone(), stage: ep.stage, early_depth_test: ep.early_depth_test, workgroup_size: ep.workgroup_size, function: self.localize_function(&ep.function), workgroup_size_overrides: ep.workgroup_size_overrides, }) .collect(); naga::Module { entry_points, ..self.into() } } } impl From> for naga::Module { fn from(derived: DerivedModule) -> Self { naga::Module { types: derived.types, constants: derived.constants, global_variables: derived.globals, global_expressions: Rc::try_unwrap(derived.global_expressions) .unwrap() .into_inner(), functions: derived.functions, special_types: derived.special_types, entry_points: Default::default(), overrides: derived.pipeline_overrides, diagnostic_filters: Default::default(), diagnostic_filter_leaf: None, } } }