1#[cfg(any(all(feature = "test", feature = "kobo"), doc))]
38mod dbus_monitor;
39pub mod dictionary_index;
40#[cfg(any(feature = "test", doc))]
41mod hello_world;
42pub mod import;
43#[cfg(any(feature = "kobo", doc))]
44mod wifi_status_monitor;
45
46use std::collections::HashMap;
47use std::sync::atomic::{AtomicBool, Ordering};
48use std::sync::mpsc::{self, Receiver, Sender};
49use std::thread::{self, JoinHandle};
50use std::time::Duration;
51
52use thiserror::Error;
53
54use crate::db::Database;
55use crate::settings::Settings;
56use crate::view::Event;
57
58#[derive(Error, Debug)]
60pub enum TaskError {
61 #[error("task '{0}' is already running")]
63 AlreadyRunning(TaskId),
64
65 #[error("task '{0}' is not running")]
67 NotRunning(TaskId),
68}
69
/// Identifies a kind of background task; [`TaskManager`] keeps at most one running task per id.
///
/// Several variants only exist with certain features enabled (or when building docs), mirroring
/// the `cfg`s on the modules that implement them.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TaskId {
    /// NOTE(review): not started anywhere in this file; presumably a stand-in id — confirm.
    Placeholder,
    /// Library import (`import::ImportTask`).
    Import,
    /// Dictionary indexing (`dictionary_index::DictionaryIndexTask`).
    DictionaryIndex,
    /// Demo task available with the `test` feature.
    #[cfg(any(feature = "test", doc))]
    HelloWorld,
    /// D-Bus monitor, available with both the `test` and `kobo` features.
    #[cfg(any(all(feature = "test", feature = "kobo"), doc))]
    DbusMonitor,
    /// Wi-Fi status monitor, available with the `kobo` feature.
    #[cfg(any(feature = "kobo", doc))]
    WifiStatusMonitor,
    /// Id used by unit tests.
    #[cfg(test)]
    TestTask,
    /// Second id used by unit tests.
    #[cfg(test)]
    TestTask2,
}
95
96impl std::fmt::Display for TaskId {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 match self {
99 TaskId::Placeholder => write!(f, "placeholder"),
100 TaskId::Import => write!(f, "import"),
101 TaskId::DictionaryIndex => write!(f, "dictionary_index"),
102 #[cfg(feature = "test")]
103 TaskId::HelloWorld => write!(f, "hello_world"),
104 #[cfg(all(feature = "test", feature = "kobo"))]
105 TaskId::DbusMonitor => write!(f, "dbus_monitor"),
106 #[cfg(feature = "kobo")]
107 TaskId::WifiStatusMonitor => write!(f, "wifi_status_monitor"),
108 #[cfg(test)]
109 TaskId::TestTask => write!(f, "test_task"),
110 #[cfg(test)]
111 TaskId::TestTask2 => write!(f, "test_task_2"),
112 }
113 }
114}
115
/// Receiving side of a task's shutdown request, handed to [`BackgroundTask::run`].
///
/// A stop request is either a `()` message on the channel or the sending side being dropped
/// (after which no message can ever arrive). Once observed, the request is latched, so every
/// later call keeps reporting it.
pub struct ShutdownSignal {
    /// Channel on which the manager sends `()` to request shutdown.
    receiver: Receiver<()>,
    /// Only set by [`ShutdownSignal::never`]. It keeps the paired sender alive so the channel
    /// never disconnects, since a disconnect counts as a stop request.
    _sender_anchor: Option<Sender<()>>,
    /// Latched to `true` once a stop request has been observed.
    stopped: AtomicBool,
}

impl ShutdownSignal {
    /// Wraps the receiving half of a shutdown channel whose sender is held elsewhere.
    fn new(receiver: Receiver<()>) -> Self {
        Self {
            receiver,
            _sender_anchor: None,
            stopped: AtomicBool::new(false),
        }
    }

    /// Creates a signal that never requests shutdown.
    ///
    /// The paired sender is stored inside the signal itself, so the channel can neither
    /// receive a message nor disconnect.
    pub fn never() -> Self {
        let (tx, rx) = mpsc::channel();
        Self {
            receiver: rx,
            _sender_anchor: Some(tx),
            stopped: AtomicBool::new(false),
        }
    }

    /// Test-only constructor exposing [`ShutdownSignal::new`].
    #[cfg(test)]
    pub fn new_for_test(receiver: Receiver<()>) -> Self {
        Self::new(receiver)
    }

    /// Non-blocking check for a stop request.
    ///
    /// Returns `true` if a shutdown message has been received, or if the sender has been
    /// dropped (the same rule [`ShutdownSignal::wait`] applies). Otherwise returns `false`.
    pub fn should_stop(&self) -> bool {
        if self.stopped.load(Ordering::Acquire) {
            return true;
        }
        match self.receiver.try_recv() {
            // A disconnected sender can never deliver a stop message. Treat it as a stop so
            // polling tasks do not spin forever (matches `wait`).
            Ok(()) | Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                self.stopped.store(true, Ordering::Release);
                true
            }
            Err(std::sync::mpsc::TryRecvError::Empty) => false,
        }
    }

    /// Blocks for up to `duration` waiting for a stop request.
    ///
    /// Returns `true` immediately if a stop was already observed. Otherwise returns `true` as
    /// soon as a shutdown message arrives or the sender disconnects, and `false` if
    /// `duration` elapses first.
    pub fn wait(&self, duration: Duration) -> bool {
        if self.stopped.load(Ordering::Acquire) {
            return true;
        }
        match self.receiver.recv_timeout(duration) {
            Ok(()) | Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                self.stopped.store(true, Ordering::Release);
                true
            }
            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => false,
        }
    }
}
194
195pub trait BackgroundTask: Send {
201 fn id(&self) -> TaskId;
203
204 fn run(&mut self, hub: &Sender<Event>, shutdown: &ShutdownSignal);
209
210 fn stop(&mut self) {}
214}
215
/// Manager-side bookkeeping for one spawned task thread.
struct RunningTask {
    /// Handle for joining the task's thread (or checking whether it has finished).
    handle: JoinHandle<()>,
    /// Sending half of the task's shutdown channel. Sending `()` asks the task to stop.
    shutdown: Sender<()>,
}
220
221pub struct TaskManager {
226 tasks: HashMap<TaskId, RunningTask>,
227 pending_import_rerun: Option<Option<usize>>,
230}
231
232impl TaskManager {
233 pub fn new() -> Self {
235 Self {
236 tasks: HashMap::new(),
237 pending_import_rerun: None,
238 }
239 }
240
241 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, task, hub), fields(task_id = tracing::field::Empty), ret))]
248 pub fn start(
249 &mut self,
250 task: Box<dyn BackgroundTask>,
251 hub: Sender<Event>,
252 ) -> Result<TaskId, TaskError> {
253 let id = task.id();
254
255 #[cfg(feature = "tracing")]
256 tracing::Span::current().record("task_id", tracing::field::display(&id));
257
258 if self.is_running(&id) {
259 return Err(TaskError::AlreadyRunning(id));
260 }
261
262 let (shutdown_tx, shutdown_rx) = mpsc::channel();
263 let shutdown_signal = ShutdownSignal::new(shutdown_rx);
264
265 let handle = thread::spawn(move || {
266 let mut task = task;
267 tracing::info!("task started");
268 task.run(&hub, &shutdown_signal);
269 task.stop();
270 tracing::info!("task stopped");
271 });
272
273 self.tasks.insert(
274 id.clone(),
275 RunningTask {
276 handle,
277 shutdown: shutdown_tx,
278 },
279 );
280
281 tracing::info!("task registered");
282 Ok(id)
283 }
284
285 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), fields(task_id = %id), ret))]
290 pub fn stop(&mut self, id: &TaskId) -> Result<(), TaskError> {
291 self.cleanup_finished();
292 if let Some(task) = self.tasks.remove(id) {
293 tracing::info!("sending shutdown signal");
294 if let Err(e) = task.shutdown.send(()) {
295 tracing::error!(error = %e, "failed to send shutdown signal");
296 }
297 if task.handle.join().is_err() {
298 tracing::error!("task thread panicked");
299 }
300 Ok(())
301 } else {
302 Err(TaskError::NotRunning(id.clone()))
303 }
304 }
305
306 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), fields(task_count = tracing::field::Empty)))]
310 pub fn stop_all(&mut self) {
311 let tasks: Vec<_> = self.tasks.drain().collect();
312
313 #[cfg(feature = "tracing")]
314 tracing::Span::current().record("task_count", tasks.len());
315
316 if !tasks.is_empty() {
317 tracing::info!("stopping all tasks");
318 }
319 for (_, task) in &tasks {
320 if let Err(e) = task.shutdown.send(()) {
321 tracing::error!(error = %e, "failed to send shutdown signal");
322 }
323 }
324 for (_, task) in tasks {
325 if task.handle.join().is_err() {
326 tracing::error!("task thread panicked");
327 }
328 }
329 }
330
331 fn cleanup_finished(&mut self) {
333 self.tasks.retain(|_, task| !task.handle.is_finished());
334 }
335
336 #[cfg_attr(
341 feature = "tracing",
342 tracing::instrument(skip(self, hub, database, settings))
343 )]
344 pub fn handle_event(
345 &mut self,
346 evt: &Event,
347 hub: &Sender<Event>,
348 database: &Database,
349 settings: &Settings,
350 ) -> bool {
351 match evt {
352 Event::ImportLibrary { library_index } => {
353 self.schedule_import(*library_index, hub, database, settings);
354 }
355 Event::ImportFinished { .. } => {
356 if let Some(pending) = self.pending_import_rerun.take() {
357 self.schedule_import(pending, hub, database, settings);
358 }
359 }
360 Event::ReindexDictionaries => {
361 self.schedule_dictionary_index(hub, database);
362 }
363 _ => {}
364 }
365 false
366 }
367
368 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
370 fn schedule_import(
371 &mut self,
372 library_index: Option<usize>,
373 hub: &Sender<Event>,
374 database: &Database,
375 settings: &Settings,
376 ) {
377 if self.is_running(&TaskId::Import) {
378 self.pending_import_rerun = Some(library_index);
379 return;
380 }
381
382 let task = Box::new(import::ImportTask::new(
383 database.clone(),
384 settings.clone(),
385 library_index,
386 ));
387
388 if let Err(e) = self.start(task, hub.clone()) {
389 tracing::warn!(error = %e, "failed to start import task");
390 }
391 }
392
393 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
395 fn schedule_dictionary_index(&mut self, hub: &Sender<Event>, database: &Database) {
396 if self.is_running(&TaskId::DictionaryIndex) {
397 tracing::debug!("stopping running dictionary index task for restart");
398 if let Err(e) = self.stop(&TaskId::DictionaryIndex) {
399 tracing::warn!(error = %e, "failed to stop dictionary_index task for restart");
400 }
401 }
402
403 let task = Box::new(dictionary_index::DictionaryIndexTask::new(database.clone()));
404
405 if let Err(e) = self.start(task, hub.clone()) {
406 tracing::warn!(error = %e, "failed to start dictionary_index task");
407 }
408 }
409
410 pub fn is_running(&mut self, id: &TaskId) -> bool {
412 self.cleanup_finished();
413 self.tasks.contains_key(id)
414 }
415
416 pub fn running_tasks(&mut self) -> Vec<TaskId> {
418 self.cleanup_finished();
419 self.tasks.keys().cloned().collect()
420 }
421}
422
423impl Default for TaskManager {
424 fn default() -> Self {
425 Self::new()
426 }
427}
428
429impl Drop for TaskManager {
430 fn drop(&mut self) {
431 self.stop_all();
432 }
433}
434
435pub fn register_startup_tasks(
445 manager: &mut TaskManager,
446 hub: Sender<Event>,
447 settings: &Settings,
448 database: &Database,
449) {
450 #[cfg(feature = "kobo")]
451 {
452 let task = Box::new(wifi_status_monitor::WifiStatusMonitorTask);
453 if let Err(e) = manager.start(task, hub.clone()) {
454 tracing::warn!(error = %e, "failed to start wifi_status_monitor task");
455 }
456 }
457
458 #[cfg(feature = "test")]
459 {
460 let task = Box::new(hello_world::HelloWorldTask);
461 if let Err(e) = manager.start(task, hub.clone()) {
462 tracing::warn!(error = %e, "failed to start hello_world task");
463 }
464
465 #[cfg(feature = "kobo")]
466 if settings.logging.enable_dbus_log {
467 let task = Box::new(dbus_monitor::DbusMonitorTask);
468 if let Err(e) = manager.start(task, hub.clone()) {
469 tracing::warn!(error = %e, "failed to start dbus_monitor task");
470 }
471 }
472 }
473
474 if settings.import.startup_trigger {
475 manager.schedule_import(None, &hub, database, settings);
476 }
477
478 let task = Box::new(dictionary_index::DictionaryIndexTask::new(database.clone()));
479 if let Err(e) = manager.start(task, hub.clone()) {
480 tracing::warn!(error = %e, "failed to start dictionary_index task");
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::time::{Duration, Instant};

    /// Polls `manager` every millisecond until `id` is no longer running. Panics after 5 s.
    fn wait_until_not_running(manager: &mut TaskManager, id: &TaskId) {
        let started = Instant::now();
        let timeout = Duration::from_secs(5);
        while started.elapsed() < timeout {
            if !manager.is_running(id) {
                return;
            }
            std::thread::sleep(Duration::from_millis(1));
        }
        panic!("task '{id}' did not finish within timeout");
    }

    /// Task whose `run` returns immediately, so its thread finishes on its own.
    struct InstantTask;

    impl BackgroundTask for InstantTask {
        fn id(&self) -> TaskId {
            TaskId::TestTask2
        }

        fn run(&mut self, _hub: &Sender<Event>, _shutdown: &ShutdownSignal) {
            // Intentionally empty.
        }
    }

    /// Task that blocks until it is asked to shut down (or 60 s elapse).
    struct WaitingTask;

    impl BackgroundTask for WaitingTask {
        fn id(&self) -> TaskId {
            TaskId::TestTask
        }

        fn run(&mut self, _hub: &Sender<Event>, shutdown: &ShutdownSignal) {
            let _stopped = shutdown.wait(Duration::from_secs(60));
        }
    }

    #[test]
    fn start_and_stop() {
        let mut manager = TaskManager::new();
        let (hub, _events) = mpsc::channel();

        let task_id = manager.start(Box::new(WaitingTask), hub).unwrap();
        assert!(manager.is_running(&task_id));

        manager.stop(&task_id).unwrap();
        assert!(!manager.is_running(&task_id));
    }

    #[test]
    fn duplicate_start_returns_error() {
        let mut manager = TaskManager::new();
        let (hub, _events) = mpsc::channel();

        manager.start(Box::new(WaitingTask), hub.clone()).unwrap();
        let second = manager.start(Box::new(WaitingTask), hub);

        assert!(matches!(
            second,
            Err(TaskError::AlreadyRunning(TaskId::TestTask))
        ));
    }

    #[test]
    fn finished_task_is_cleaned_up() {
        let mut manager = TaskManager::new();
        let (hub, _events) = mpsc::channel();

        let task_id = manager.start(Box::new(InstantTask), hub).unwrap();

        wait_until_not_running(&mut manager, &task_id);
        assert!(!manager.is_running(&task_id));
    }

    #[test]
    fn stop_finished_task_returns_not_running() {
        let mut manager = TaskManager::new();
        let (hub, _events) = mpsc::channel();

        let task_id = manager.start(Box::new(InstantTask), hub).unwrap();

        wait_until_not_running(&mut manager, &task_id);
        let outcome = manager.stop(&task_id);

        assert!(matches!(
            outcome,
            Err(TaskError::NotRunning(TaskId::TestTask2))
        ));
    }

    #[test]
    fn running_tasks_excludes_finished() {
        let mut manager = TaskManager::new();
        let (hub, _events) = mpsc::channel();

        manager.start(Box::new(WaitingTask), hub.clone()).unwrap();
        let finished_id = manager.start(Box::new(InstantTask), hub).unwrap();

        wait_until_not_running(&mut manager, &finished_id);
        assert_eq!(manager.running_tasks(), vec![TaskId::TestTask]);

        manager.stop_all();
    }

    #[test]
    fn stop_all_stops_everything() {
        let mut manager = TaskManager::new();
        let (hub, _events) = mpsc::channel();

        manager.start(Box::new(WaitingTask), hub).unwrap();
        manager.stop_all();

        assert!(!manager.is_running(&TaskId::TestTask));
    }
}