1use std::collections::VecDeque;
8
9use crate::{
10 component::ComponentId,
11 layout::{LayoutNode, Point, Rect},
12};
13
14use super::EventError;
15
16#[derive(Debug)]
18pub struct HitTester {
19 pub stats: HitTestStats,
21}
22
/// Statistics gathered by [`HitTester`].
///
/// `hit_tests` and `hit_test_time_us` accumulate across passes, while
/// `nodes_tested` and `hits_found` describe only the most recent pass (they
/// are zeroed at the start of every hit test).
///
/// `PartialEq`/`Eq` are derived so that snapshots can be compared directly,
/// e.g. with `assert_eq!`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct HitTestStats {
    /// Number of hit-test passes performed.
    pub hit_tests: u64,
    /// Total time spent in hit-test passes, in microseconds. Every pass
    /// contributes at least one microsecond.
    pub hit_test_time_us: u64,
    /// Number of layout nodes examined during the most recent pass.
    pub nodes_tested: u32,
    /// Number of hits returned by the most recent pass.
    pub hits_found: u32,
}
35
36impl HitTester {
37 pub fn new() -> Self {
39 Self {
40 stats: HitTestStats::default(),
41 }
42 }
43
44 pub fn hit_test(
47 &mut self,
48 point: Point,
49 layout_root: &LayoutNode,
50 ) -> Result<Vec<ComponentId>, EventError> {
51 let start_time = std::time::Instant::now();
52 self.stats.nodes_tested = 0;
53 self.stats.hits_found = 0;
54
55 let mut hits = Vec::new();
56 self.hit_test_recursive(point, layout_root, &mut hits)?;
57
58 let elapsed = start_time.elapsed();
60 self.stats.hit_tests += 1;
61 let elapsed_us = std::cmp::max(1, elapsed.as_micros() as u64);
63 self.stats.hit_test_time_us += elapsed_us;
64 self.stats.hits_found = hits.len() as u32;
65
66 Ok(hits)
67 }
68
69 fn hit_test_recursive(
71 &mut self,
72 point: Point,
73 node: &LayoutNode,
74 hits: &mut Vec<ComponentId>,
75 ) -> Result<(), EventError> {
76 self.stats.nodes_tested += 1;
77
78 if node.layout.rect.contains_point(point) {
80 hits.insert(0, node.id);
82
83 for child in node.children.iter().rev() {
85 self.hit_test_recursive(point, child, hits)?;
86 }
87 }
88
89 Ok(())
90 }
91
92 pub fn hit_test_depth_first(
94 &mut self,
95 point: Point,
96 layout_root: &LayoutNode,
97 ) -> Result<Vec<ComponentId>, EventError> {
98 let start_time = std::time::Instant::now();
99 self.stats.nodes_tested = 0;
100 self.stats.hits_found = 0;
101
102 let mut hits = Vec::new();
103 let mut stack = VecDeque::new();
104 stack.push_back(layout_root);
105
106 while let Some(node) = stack.pop_back() {
107 self.stats.nodes_tested += 1;
108
109 if node.layout.rect.contains_point(point) {
110 hits.push(node.id);
111
112 for child in node.children.iter().rev() {
114 stack.push_back(child);
115 }
116 }
117 }
118
119 let elapsed = start_time.elapsed();
121 self.stats.hit_tests += 1;
122 let elapsed_us = std::cmp::max(1, elapsed.as_micros() as u64);
124 self.stats.hit_test_time_us += elapsed_us;
125 self.stats.hits_found = hits.len() as u32;
126
127 Ok(hits)
128 }
129
130 pub fn hit_test_top(
132 &mut self,
133 point: Point,
134 layout_root: &LayoutNode,
135 ) -> Result<Option<ComponentId>, EventError> {
136 let hits = self.hit_test(point, layout_root)?;
137 Ok(hits.first().copied())
138 }
139
140 pub fn hit_test_region(
142 &mut self,
143 region: Rect,
144 layout_root: &LayoutNode,
145 ) -> Result<Vec<ComponentId>, EventError> {
146 let start_time = std::time::Instant::now();
147 self.stats.nodes_tested = 0;
148 self.stats.hits_found = 0;
149
150 let mut hits = Vec::new();
151 self.hit_test_region_recursive(region, layout_root, &mut hits)?;
152
153 let elapsed = start_time.elapsed();
155 self.stats.hit_tests += 1;
156 let elapsed_us = std::cmp::max(1, elapsed.as_micros() as u64);
158 self.stats.hit_test_time_us += elapsed_us;
159 self.stats.hits_found = hits.len() as u32;
160
161 Ok(hits)
162 }
163
164 fn hit_test_region_recursive(
166 &mut self,
167 region: Rect,
168 node: &LayoutNode,
169 hits: &mut Vec<ComponentId>,
170 ) -> Result<(), EventError> {
171 self.stats.nodes_tested += 1;
172
173 if self.rect_intersects(region, node.layout.rect) {
175 hits.push(node.id);
176
177 for child in &node.children {
179 self.hit_test_region_recursive(region, child, hits)?;
180 }
181 }
182
183 Ok(())
184 }
185
186 fn rect_intersects(&self, rect1: Rect, rect2: Rect) -> bool {
188 rect1.x() < rect2.max_x()
189 && rect1.max_x() > rect2.x()
190 && rect1.y() < rect2.max_y()
191 && rect1.max_y() > rect2.y()
192 }
193
194 pub fn reset_stats(&mut self) {
196 self.stats = HitTestStats::default();
197 }
198
199 pub fn get_stats(&self) -> &HitTestStats {
201 &self.stats
202 }
203}
204
205impl Default for HitTester {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211impl std::fmt::Display for HitTestStats {
212 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213 write!(
214 f,
215 "Hit Test Stats: {} tests, {}μs total time, {} nodes tested, {} hits found",
216 self.hit_tests, self.hit_test_time_us, self.nodes_tested, self.hits_found
217 )
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        component::ComponentId,
        layout::{LayoutNode, LayoutStyle},
    };

    /// Builds the fixture tree shared by the tests below. The rectangles are
    /// listed with the arguments passed to `Rect::new`:
    ///
    /// ```text
    /// root                (0, 0, 400, 300)
    /// ├── child1          (0, 0, 200, 150)
    /// │   └── nested      (50, 50, 100, 50)
    /// └── child2          (200, 150, 200, 150)
    /// ```
    fn create_test_layout() -> LayoutNode {
        let root_id = ComponentId::new();
        let root_style = LayoutStyle {
            width: crate::layout::Dimension::Points(400.0),
            height: crate::layout::Dimension::Points(300.0),
            ..Default::default()
        };
        let mut root = LayoutNode::new(root_id, root_style);
        root.layout.rect = Rect::new(0.0, 0.0, 400.0, 300.0);

        let child1_id = ComponentId::new();
        let child1_style = LayoutStyle::default();
        let mut child1 = LayoutNode::new(child1_id, child1_style);
        child1.layout.rect = Rect::new(0.0, 0.0, 200.0, 150.0);

        let child2_id = ComponentId::new();
        let child2_style = LayoutStyle::default();
        let mut child2 = LayoutNode::new(child2_id, child2_style);
        child2.layout.rect = Rect::new(200.0, 150.0, 200.0, 150.0);

        let nested_id = ComponentId::new();
        let nested_style = LayoutStyle::default();
        let mut nested = LayoutNode::new(nested_id, nested_style);
        nested.layout.rect = Rect::new(50.0, 50.0, 100.0, 50.0);

        child1.add_child(nested);
        root.add_child(child1);
        root.add_child(child2);

        root
    }

    #[test]
    fn test_hit_tester_creation() {
        let hit_tester = HitTester::new();
        assert_eq!(hit_tester.stats.hit_tests, 0);
        assert_eq!(hit_tester.stats.nodes_tested, 0);
    }

    #[test]
    fn test_simple_hit_test() {
        let mut hit_tester = HitTester::new();
        let layout = create_test_layout();

        // Inside the root only: outside both children.
        let point = Point::new(300.0, 50.0);
        let hits = hit_tester.hit_test(point, &layout).unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0], layout.id);
    }

    #[test]
    fn test_nested_hit_test() {
        let mut hit_tester = HitTester::new();
        let layout = create_test_layout();

        // Inside root, child1 and nested.
        let point = Point::new(100.0, 75.0);
        let hits = hit_tester.hit_test(point, &layout).unwrap();

        // Deepest node first, root last.
        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0], layout.children[0].children[0].id);
        assert_eq!(hits[1], layout.children[0].id);
        assert_eq!(hits[2], layout.id);
    }

    #[test]
    fn test_hit_test_top() {
        let mut hit_tester = HitTester::new();
        let layout = create_test_layout();

        let point = Point::new(100.0, 75.0);
        let top_hit = hit_tester.hit_test_top(point, &layout).unwrap();

        assert_eq!(top_hit, Some(layout.children[0].children[0].id));
    }

    #[test]
    fn test_hit_test_miss() {
        let mut hit_tester = HitTester::new();
        let layout = create_test_layout();

        // Outside the root rectangle entirely.
        let point = Point::new(500.0, 500.0);
        let hits = hit_tester.hit_test(point, &layout).unwrap();

        assert!(hits.is_empty());
    }

    #[test]
    fn test_hit_test_region() {
        let mut hit_tester = HitTester::new();
        let layout = create_test_layout();

        let region = Rect::new(150.0, 100.0, 100.0, 100.0);
        let hits = hit_tester.hit_test_region(region, &layout).unwrap();

        assert!(hits.len() >= 2);
        assert!(hits.contains(&layout.id));
    }

    #[test]
    fn test_hit_test_depth_first() {
        let mut hit_tester = HitTester::new();
        let layout = create_test_layout();

        let point = Point::new(100.0, 75.0);
        let hits = hit_tester.hit_test_depth_first(point, &layout).unwrap();

        assert!(!hits.is_empty());
        assert!(hits.contains(&layout.id));
        assert!(hits.contains(&layout.children[0].id));
        assert!(hits.contains(&layout.children[0].children[0].id));

        // The iterative variant reports hits in visiting order: root first,
        // deepest node last.
        let expected = vec![
            layout.id,
            layout.children[0].id,
            layout.children[0].children[0].id,
        ];
        assert_eq!(hits, expected);
    }

    #[test]
    fn test_hit_test_statistics() {
        let mut hit_tester = HitTester::new();
        let layout = create_test_layout();

        let point = Point::new(100.0, 75.0);
        hit_tester.hit_test(point, &layout).unwrap();

        assert_eq!(hit_tester.stats.hit_tests, 1);
        // root, child2 (miss), child1 and nested are each examined once.
        assert_eq!(hit_tester.stats.nodes_tested, 4);
        assert!(hit_tester.stats.hit_test_time_us > 0);
        assert_eq!(hit_tester.stats.hits_found, 3);

        let stats_str = format!("{}", hit_tester.stats);
        assert!(stats_str.contains("1 tests"));
        assert!(stats_str.contains("3 hits found"));
    }

    #[test]
    fn test_per_pass_counters_are_reset() {
        let mut hit_tester = HitTester::new();
        let layout = create_test_layout();

        hit_tester
            .hit_test(Point::new(100.0, 75.0), &layout)
            .unwrap();
        hit_tester
            .hit_test(Point::new(500.0, 500.0), &layout)
            .unwrap();

        // Cumulative counter keeps growing; per-pass counters describe only
        // the second (missing) pass, which examines just the root.
        assert_eq!(hit_tester.stats.hit_tests, 2);
        assert_eq!(hit_tester.stats.nodes_tested, 1);
        assert_eq!(hit_tester.stats.hits_found, 0);
    }

    #[test]
    fn test_reset_stats() {
        let mut hit_tester = HitTester::new();
        let layout = create_test_layout();

        hit_tester
            .hit_test(Point::new(100.0, 75.0), &layout)
            .unwrap();
        hit_tester.reset_stats();

        let stats = hit_tester.get_stats();
        assert_eq!(stats.hit_tests, 0);
        assert_eq!(stats.hit_test_time_us, 0);
        assert_eq!(stats.nodes_tested, 0);
        assert_eq!(stats.hits_found, 0);
    }

    #[test]
    fn test_rect_intersection() {
        let hit_tester = HitTester::new();

        let rect1 = Rect::new(0.0, 0.0, 100.0, 100.0);
        let rect2 = Rect::new(50.0, 50.0, 100.0, 100.0);
        let rect3 = Rect::new(200.0, 200.0, 100.0, 100.0);

        assert!(hit_tester.rect_intersects(rect1, rect2));
        assert!(hit_tester.rect_intersects(rect2, rect1));
        assert!(!hit_tester.rect_intersects(rect1, rect3));
        assert!(!hit_tester.rect_intersects(rect3, rect1));
    }
}