1#![allow(deprecated)]
3
4use futures::future::pending;
5use futures::{Stream, StreamExt as _};
6
7use crate::gateway::{CollectorCallback, ShardMessenger};
8use crate::model::prelude::*;
9
10pub fn collect<T: Send + 'static>(
34 shard: &ShardMessenger,
35 extractor: impl Fn(&Event) -> Option<T> + Send + Sync + 'static,
36) -> impl Stream<Item = T> {
37 let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
38
39 shard.add_collector(CollectorCallback(Box::new(move |event| match extractor(event) {
41 Some(item) => sender.send(item).is_ok(),
43 None => !sender.is_closed(),
44 })));
45
46 futures::stream::poll_fn(move |cx| receiver.poll_recv(cx))
48}
49
/// Generates a builder-style collector type for one kind of gateway event.
///
/// Input syntax:
/// - optional outer attributes applied to the generated struct;
/// - the collector type name and the collected item type;
/// - an extractor pattern matched against `&Event`, followed by `=> ident`, which names the
///   binding that holds a reference to the extracted item;
/// - zero or more `name: Type => expr,` built-in filters. Each filter generates a builder method
///   `name(Type)`. `expr` may refer to the extracted item and to `name` (a `&Type`), and must
///   evaluate to `true` for the item to pass.
macro_rules! make_specific_collector {
    (
        $( #[ $($meta:tt)* ] )*
        $collector_type:ident, $item_type:ident,
        $extractor:pat => $extracted_item:ident,
        $( $filter_name:ident: $filter_type:ty => $filter_passes:expr, )*
    ) => {
        #[doc = concat!(
            "A [`", stringify!($collector_type), "`] receives [`", stringify!($item_type),
            "`]s that match the given filters for a set duration."
        )]
        $( #[ $($meta)* ] )*
        #[must_use]
        pub struct $collector_type {
            shard: ShardMessenger,
            duration: Option<std::time::Duration>,
            filter: Option<Box<dyn Fn(&$item_type) -> bool + Send + Sync>>,
            $( $filter_name: Option<$filter_type>, )*
        }

        impl $collector_type {
            /// Creates a collector on the given shard with no timeout, no built-in filters and
            /// no custom filter.
            pub fn new(shard: impl AsRef<ShardMessenger>) -> Self {
                Self {
                    shard: shard.as_ref().clone(),
                    duration: None,
                    filter: None,
                    $( $filter_name: None, )*
                }
            }

            /// Stops collecting once `duration` has elapsed.
            ///
            /// The timer lives inside a lazy future, so the countdown starts when the stream
            /// returned by [`Self::stream`] is first polled (immediately for `.next()` or
            /// `.await`), not at the moment `stream()` is called.
            pub fn timeout(mut self, duration: std::time::Duration) -> Self {
                self.duration = Some(duration);
                self
            }

            /// Sets a custom predicate that collected items must also satisfy.
            ///
            /// It is checked after every built-in filter has passed. Calling this again
            /// replaces the previously set predicate.
            pub fn filter(mut self, filter: impl Fn(&$item_type) -> bool + Send + Sync + 'static) -> Self {
                self.filter = Some(Box::new(filter));
                self
            }

            $(
                #[doc = concat!("Filters [`", stringify!($item_type), "`]s by a specific [`", stringify!($filter_type), "`].")]
                ///
                /// Calling this again replaces the previously set value.
                pub fn $filter_name(mut self, $filter_name: $filter_type) -> Self {
                    self.$filter_name = Some($filter_name);
                    self
                }
            )*

            #[doc = concat!("Returns a [`Stream`] over all collected [`", stringify!($item_type), "`]s.")]
            ///
            /// An item is cloned out of its event only after all filters have passed. If a
            /// timeout was set, the stream ends when it fires.
            pub fn stream(self) -> impl Stream<Item = $item_type> {
                // Built-in filters first, then the custom predicate; any rejection short-circuits.
                let filters_pass = move |$extracted_item: &$item_type| {
                    $( if let Some($filter_name) = &self.$filter_name {
                        if !$filter_passes {
                            return false;
                        }
                    } )*
                    if let Some(custom_filter) = &self.filter {
                        if !custom_filter($extracted_item) {
                            return false;
                        }
                    }
                    true
                };

                // Without a timeout this future never resolves, so `take_until` never cuts the
                // stream short.
                let timeout = async move { match self.duration {
                    Some(d) => tokio::time::sleep(d).await,
                    None => pending::<()>().await,
                } };

                let stream = collect(&self.shard, move |event| match event {
                    $extractor if filters_pass($extracted_item) => Some($extracted_item.clone()),
                    _ => None,
                });
                stream.take_until(Box::pin(timeout))
            }

            /// Deprecated alias for [`Self::stream`].
            #[deprecated = "use `.stream()` instead"]
            pub fn build(self) -> impl Stream<Item = $item_type> {
                self.stream()
            }

            #[doc = concat!("Returns the next [`", stringify!($item_type), "`] that passes the filters.")]
            #[doc = concat!("You can also call `.await` on the [`", stringify!($collector_type), "`] directly.")]
            pub async fn next(self) -> Option<$item_type> {
                self.stream().next().await
            }
        }

        // Lets callers write `collector.await`, which is the same as `collector.next().await`.
        impl std::future::IntoFuture for $collector_type {
            type Output = Option<$item_type>;
            type IntoFuture = futures::future::BoxFuture<'static, Self::Output>;

            fn into_future(self) -> Self::IntoFuture {
                Box::pin(self.next())
            }
        }
    };
}
153
154make_specific_collector!(
155 ComponentInteractionCollector, ComponentInteraction,
157 Event::InteractionCreate(InteractionCreateEvent {
159 interaction: Interaction::Component(interaction),
160 }) => interaction,
161 author_id: UserId => interaction.user.id == *author_id,
167 channel_id: ChannelId => interaction.channel_id == *channel_id,
168 guild_id: GuildId => interaction.guild_id.map_or(true, |x| x == *guild_id),
169 message_id: MessageId => interaction.message.id == *message_id,
170 custom_ids: Vec<String> => custom_ids.contains(&interaction.data.custom_id),
171);
172make_specific_collector!(
173 ModalInteractionCollector, ModalInteraction,
174 Event::InteractionCreate(InteractionCreateEvent {
175 interaction: Interaction::Modal(interaction),
176 }) => interaction,
177 author_id: UserId => interaction.user.id == *author_id,
178 channel_id: ChannelId => interaction.channel_id == *channel_id,
179 guild_id: GuildId => interaction.guild_id.map_or(true, |g| g == *guild_id),
180 message_id: MessageId => interaction.message.as_ref().map_or(true, |m| m.id == *message_id),
181 custom_ids: Vec<String> => custom_ids.contains(&interaction.data.custom_id),
182);
183make_specific_collector!(
184 ReactionCollector, Reaction,
185 Event::ReactionAdd(ReactionAddEvent { reaction }) => reaction,
186 author_id: UserId => reaction.user_id.map_or(true, |a| a == *author_id),
187 channel_id: ChannelId => reaction.channel_id == *channel_id,
188 guild_id: GuildId => reaction.guild_id.map_or(true, |g| g == *guild_id),
189 message_id: MessageId => reaction.message_id == *message_id,
190);
191make_specific_collector!(
192 MessageCollector, Message,
193 Event::MessageCreate(MessageCreateEvent { message }) => message,
194 author_id: UserId => message.author.id == *author_id,
195 channel_id: ChannelId => message.channel_id == *channel_id,
196 guild_id: GuildId => message.guild_id.map_or(true, |g| g == *guild_id),
197);
198make_specific_collector!(
199 #[deprecated = "prefer the stand-alone collect() function to collect arbitrary events"]
200 EventCollector, Event,
201 event => event,
202);