serenity/
collector.rs

// Without this, code generated for our own deprecated items (e.g. `EventCollector`) would trigger deprecation warnings.
2#![allow(deprecated)]
3
4use futures::future::pending;
5use futures::{Stream, StreamExt as _};
6
7use crate::gateway::{CollectorCallback, ShardMessenger};
8use crate::model::prelude::*;
9
10/// Fundamental collector function. All collector types in this module are just wrappers around
11/// this function.
12///
13/// Example: creating a collector stream over removed reactions
14/// ```rust
15/// # use std::time::Duration;
16/// # use futures::StreamExt as _;
17/// # use serenity::model::prelude::Event;
18/// # use serenity::gateway::ShardMessenger;
19/// # use serenity::collector::collect;
20/// # async fn example_(shard: &ShardMessenger) {
21/// let stream = collect(shard, |event| match event {
22///     Event::ReactionRemove(event) => Some(event.reaction.clone()),
23///     _ => None,
24/// });
25///
26/// stream
27///     .for_each(|reaction| async move {
28///         println!("{}: removed {}", reaction.channel_id, reaction.emoji);
29///     })
30///     .await;
31/// # }
32/// ```
33pub fn collect<T: Send + 'static>(
34    shard: &ShardMessenger,
35    extractor: impl Fn(&Event) -> Option<T> + Send + Sync + 'static,
36) -> impl Stream<Item = T> {
37    let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
38
39    // Register an event callback in the shard. It's kept alive as long as we return `true`
40    shard.add_collector(CollectorCallback(Box::new(move |event| match extractor(event) {
41        // If this event matches, we send it to the receiver stream
42        Some(item) => sender.send(item).is_ok(),
43        None => !sender.is_closed(),
44    })));
45
46    // Convert the mpsc Receiver into a Stream
47    futures::stream::poll_fn(move |cx| receiver.poll_recv(cx))
48}
49
/// Generates a builder-style collector type that is a thin wrapper around [`collect`].
///
/// Invocation syntax, in order:
/// - optional outer attributes, forwarded to the generated struct;
/// - `CollectorType, ItemType,`: the name of the generated struct and the type it yields;
/// - `pattern => ident,`: a pattern matched against each `&Event`. On a match, `ident` refers to
///   the candidate item. The item is checked against the filters and then cloned into the stream;
/// - zero or more `name: Type => expr,` built-in filters:
///   - Each one generates a builder method `name(Type)`.
///   - Once a filter is set, its `expr` is evaluated with `name` bound to `&Type` and `ident` in
///     scope, and must be `true` for the item to be collected.
///   - Filters that were never set are skipped.
macro_rules! make_specific_collector {
    (
        $( #[ $($meta:tt)* ] )*
        $collector_type:ident, $item_type:ident,
        $extractor:pat => $extracted_item:ident,
        $( $filter_name:ident: $filter_type:ty => $filter_passes:expr, )*
    ) => {
        #[doc = concat!("A [`", stringify!($collector_type), "`] receives [`", stringify!($item_type), "`]s that match the given filters, optionally until a timeout elapses.")]
        $( #[ $($meta)* ] )*
        #[must_use]
        pub struct $collector_type {
            shard: ShardMessenger,
            duration: Option<std::time::Duration>,
            filter: Option<Box<dyn Fn(&$item_type) -> bool + Send + Sync>>,
            $( $filter_name: Option<$filter_type>, )*
        }

        impl $collector_type {
            /// Creates a new collector without any filters or timeout configured.
            pub fn new(shard: impl AsRef<ShardMessenger>) -> Self {
                Self {
                    shard: shard.as_ref().clone(),
                    duration: None,
                    filter: None,
                    $( $filter_name: None, )*
                }
            }

            #[doc = concat!("Sets how long the collector keeps receiving [`", stringify!($item_type), "`]s.")]
            #[doc = ""]
            #[doc = "If this is never called, collection is not time-limited."]
            pub fn timeout(mut self, duration: std::time::Duration) -> Self {
                self.duration = Some(duration);
                self
            }

            /// Sets a custom filter function.
            ///
            /// An item is only collected if this returns `true` *and* every configured built-in
            /// filter passes. Calling this again replaces the previously set filter.
            pub fn filter(mut self, filter: impl Fn(&$item_type) -> bool + Send + Sync + 'static) -> Self {
                self.filter = Some(Box::new(filter));
                self
            }

            $(
                #[doc = concat!("Filters [`", stringify!($item_type), "`]s by a specific [`", stringify!($filter_type), "`].")]
                pub fn $filter_name(mut self, $filter_name: $filter_type) -> Self {
                    self.$filter_name = Some($filter_name);
                    self
                }
            )*

            #[doc = concat!("Returns a [`Stream`] over all collected [`", stringify!($item_type), "`]s.")]
            #[doc = ""]
            #[doc = "The collector is registered on the shard as soon as this is called. The stream ends once the configured timeout (if any) elapses."]
            pub fn stream(self) -> impl Stream<Item = $item_type> {
                let filters_pass = move |$extracted_item: &$item_type| {
                    // Check each of the built-in filters (author_id, channel_id, etc.)
                    $( if let Some($filter_name) = &self.$filter_name {
                        if !$filter_passes {
                            return false;
                        }
                    } )*
                    // Check the callback-based filter
                    if let Some(custom_filter) = &self.filter {
                        if !custom_filter($extracted_item) {
                            return false;
                        }
                    }
                    true
                };

                // A future that completes once the timeout is triggered (never, if none is set)
                let timeout = async move { match self.duration {
                    Some(d) => tokio::time::sleep(d).await,
                    None => pending::<()>().await,
                } };

                let stream = collect(&self.shard, move |event| match event {
                    $extractor if filters_pass($extracted_item) => Some($extracted_item.clone()),
                    _ => None,
                });
                // Need to Box::pin this, or else users have to `pin_mut!()` the stream to the stack
                stream.take_until(Box::pin(timeout))
            }

            /// Deprecated, use [`Self::stream()`] instead.
            #[deprecated = "use `.stream()` instead"]
            pub fn build(self) -> impl Stream<Item = $item_type> {
                self.stream()
            }

            #[doc = concat!("Returns the next [`", stringify!($item_type), "`] which passes the filters, or `None` if the stream ends first (e.g. because the timeout elapsed).")]
            #[doc = ""]
            #[doc = concat!("You can also call `.await` on the [`", stringify!($collector_type), "`] directly.")]
            pub async fn next(self) -> Option<$item_type> {
                self.stream().next().await
            }
        }

        // Allows `collector.await` as a shorthand for `collector.next().await`.
        impl std::future::IntoFuture for $collector_type {
            type Output = Option<$item_type>;
            type IntoFuture = futures::future::BoxFuture<'static, Self::Output>;

            fn into_future(self) -> Self::IntoFuture {
                Box::pin(self.next())
            }
        }
    };
}
153
154make_specific_collector!(
155    // First line has name of the collector type, and the type of the collected items.
156    ComponentInteractionCollector, ComponentInteraction,
157    // This defines the extractor pattern, which extracts the data we want to collect from an Event.
158    Event::InteractionCreate(InteractionCreateEvent {
159        interaction: Interaction::Component(interaction),
160    }) => interaction,
161    // All following lines define built-in filters of the collector.
162    // Each line consists of:
163    // - the filter name (the name of the generated builder-like method on the collector type)
164    // - filter argument type (used as argument of the builder-like method on the collector type)
165    // - filter expression (this expressoin must return true to let the event through)
166    author_id: UserId => interaction.user.id == *author_id,
167    channel_id: ChannelId => interaction.channel_id == *channel_id,
168    guild_id: GuildId => interaction.guild_id.map_or(true, |x| x == *guild_id),
169    message_id: MessageId => interaction.message.id == *message_id,
170    custom_ids: Vec<String> => custom_ids.contains(&interaction.data.custom_id),
171);
172make_specific_collector!(
173    ModalInteractionCollector, ModalInteraction,
174    Event::InteractionCreate(InteractionCreateEvent {
175        interaction: Interaction::Modal(interaction),
176    }) => interaction,
177    author_id: UserId => interaction.user.id == *author_id,
178    channel_id: ChannelId => interaction.channel_id == *channel_id,
179    guild_id: GuildId => interaction.guild_id.map_or(true, |g| g == *guild_id),
180    message_id: MessageId => interaction.message.as_ref().map_or(true, |m| m.id == *message_id),
181    custom_ids: Vec<String> => custom_ids.contains(&interaction.data.custom_id),
182);
183make_specific_collector!(
184    ReactionCollector, Reaction,
185    Event::ReactionAdd(ReactionAddEvent { reaction }) => reaction,
186    author_id: UserId => reaction.user_id.map_or(true, |a| a == *author_id),
187    channel_id: ChannelId => reaction.channel_id == *channel_id,
188    guild_id: GuildId => reaction.guild_id.map_or(true, |g| g == *guild_id),
189    message_id: MessageId => reaction.message_id == *message_id,
190);
191make_specific_collector!(
192    MessageCollector, Message,
193    Event::MessageCreate(MessageCreateEvent { message }) => message,
194    author_id: UserId => message.author.id == *author_id,
195    channel_id: ChannelId => message.channel_id == *channel_id,
196    guild_id: GuildId => message.guild_id.map_or(true, |g| g == *guild_id),
197);
198make_specific_collector!(
199    #[deprecated = "prefer the stand-alone collect() function to collect arbitrary events"]
200    EventCollector, Event,
201    event => event,
202);