1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
//
// Copyright 2024, Colias Group, LLC
//
// SPDX-License-Identifier: BSD-2-Clause
//

use core::convert::Infallible;

use sel4_driver_interfaces::timer::{NumTimers, Timers};
use sel4_driver_interfaces::HandleInterrupt;
use sel4_microkit::{Channel, Handler, MessageInfo};
use sel4_microkit_message::MessageInfoExt;

use super::message_types::*;

/// `sel4_microkit::Handler` that exposes a timer `Driver` to a single client.
///
/// Requests arriving via protected procedure calls on `client` are dispatched to the driver.
/// Notifications on `timer` are handed to the driver's interrupt handler, acknowledged, and
/// then forwarded to `client` as a notification.
#[derive(Clone, Debug)]
pub struct HandlerImpl<Driver> {
    /// The wrapped timer driver.
    driver: Driver,
    /// Channel on which timer interrupts arrive; also the channel whose IRQ is acknowledged.
    timer: Channel,
    /// The only channel allowed to issue requests; notified after each timer interrupt.
    client: Channel,
    /// Timer count reported by the driver's `timer_layout()` at construction time; client
    /// requests naming a timer index `>= num_timers` are rejected.
    num_timers: usize,
}

impl<Driver: Timers<TimerLayout = NumTimers>> HandlerImpl<Driver> {
    pub fn new(mut driver: Driver, timer: Channel, client: Channel) -> Result<Self, Driver::Error> {
        let num_timers = driver.timer_layout()?.0;
        Ok(Self {
            driver,
            timer,
            client,
            num_timers,
        })
    }

    fn guard_timer(&self, timer: usize) -> Result<(), ErrorResponse> {
        if timer < self.num_timers {
            Ok(())
        } else {
            Err(ErrorResponse::TimerOutOfBounds)
        }
    }
}

impl<Driver> Handler for HandlerImpl<Driver>
where
    Driver: Timers<TimerLayout = NumTimers, Timer = usize> + HandleInterrupt,
{
    type Error = Infallible;

    fn notified(&mut self, channel: Channel) -> Result<(), Self::Error> {
        if channel == self.timer {
            self.driver.handle_interrupt();
            self.timer.irq_ack().unwrap();
            self.client.notify();
        } else {
            panic!("unexpected channel: {channel:?}");
        }
        Ok(())
    }

    fn protected(
        &mut self,
        channel: Channel,
        msg_info: MessageInfo,
    ) -> Result<MessageInfo, Self::Error> {
        if channel == self.client {
            Ok(match msg_info.recv_using_postcard::<Request>() {
                Ok(req) => {
                    let resp = match req {
                        Request::GetTime => self
                            .driver
                            .get_time()
                            .map(SuccessResponse::GetTime)
                            .map_err(|_| ErrorResponse::Unspecified),
                        Request::NumTimers => self
                            .driver
                            .timer_layout()
                            .map(|NumTimers(n)| SuccessResponse::NumTimers(n))
                            .map_err(|_| ErrorResponse::Unspecified),
                        Request::SetTimeout { timer, relative } => {
                            self.guard_timer(timer).and_then(|_| {
                                self.driver
                                    .set_timeout_on(timer, relative)
                                    .map(|_| SuccessResponse::SetTimeout)
                                    .map_err(|_| ErrorResponse::Unspecified)
                            })
                        }
                        Request::ClearTimeout { timer } => self.guard_timer(timer).and_then(|_| {
                            self.driver
                                .clear_timeout_on(timer)
                                .map(|_| SuccessResponse::ClearTimeout)
                                .map_err(|_| ErrorResponse::Unspecified)
                        }),
                    };
                    MessageInfo::send_using_postcard(resp).unwrap()
                }
                Err(_) => MessageInfo::send_unspecified_error(),
            })
        } else {
            panic!("unexpected channel: {channel:?}");
        }
    }
}