1use std::borrow::Cow;
2use std::sync::Arc;
3
4use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
5use tokio::sync::RwLock;
6use tokio_tungstenite::tungstenite;
7use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
8use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
9use tracing::{debug, error, info, instrument, trace, warn};
10use typemap_rev::TypeMap;
11
12use super::event::ShardStageUpdateEvent;
13#[cfg(feature = "collector")]
14use super::CollectorCallback;
15#[cfg(feature = "voice")]
16use super::VoiceGatewayManager;
17use super::{ShardId, ShardManager, ShardRunnerMessage};
18#[cfg(feature = "cache")]
19use crate::cache::Cache;
20use crate::client::dispatch::dispatch_model;
21use crate::client::{Context, EventHandler, RawEventHandler};
22#[cfg(feature = "framework")]
23use crate::framework::Framework;
24use crate::gateway::{GatewayError, ReconnectType, Shard, ShardAction};
25use crate::http::Http;
26use crate::internal::prelude::*;
27use crate::internal::tokio::spawn_named;
28use crate::model::event::{Event, GatewayEvent};
29
30pub struct ShardRunner {
32 data: Arc<RwLock<TypeMap>>,
33 event_handlers: Vec<Arc<dyn EventHandler>>,
34 raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
35 #[cfg(feature = "framework")]
36 framework: Option<Arc<dyn Framework>>,
37 manager: Arc<ShardManager>,
38 runner_rx: Receiver<ShardRunnerMessage>,
40 runner_tx: Sender<ShardRunnerMessage>,
42 pub(crate) shard: Shard,
43 #[cfg(feature = "voice")]
44 voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
45 #[cfg(feature = "cache")]
46 pub cache: Arc<Cache>,
47 pub http: Arc<Http>,
48 #[cfg(feature = "collector")]
49 pub(crate) collectors: Arc<std::sync::Mutex<Vec<CollectorCallback>>>,
50}
51
52impl ShardRunner {
53 pub fn new(opt: ShardRunnerOptions) -> Self {
55 let (tx, rx) = mpsc::unbounded();
56
57 Self {
58 runner_rx: rx,
59 runner_tx: tx,
60 data: opt.data,
61 event_handlers: opt.event_handlers,
62 raw_event_handlers: opt.raw_event_handlers,
63 #[cfg(feature = "framework")]
64 framework: opt.framework,
65 manager: opt.manager,
66 shard: opt.shard,
67 #[cfg(feature = "voice")]
68 voice_manager: opt.voice_manager,
69 #[cfg(feature = "cache")]
70 cache: opt.cache,
71 http: opt.http,
72 #[cfg(feature = "collector")]
73 collectors: Arc::new(std::sync::Mutex::new(vec![])),
74 }
75 }
76
77 #[instrument(skip(self))]
98 pub async fn run(&mut self) -> Result<()> {
99 info!("[ShardRunner {:?}] Running", self.shard.shard_info());
100
101 loop {
102 trace!("[ShardRunner {:?}] loop iteration started.", self.shard.shard_info());
103 if !self.recv().await {
104 return Ok(());
105 }
106
107 if !self.shard.do_heartbeat().await {
109 warn!("[ShardRunner {:?}] Error heartbeating", self.shard.shard_info(),);
110
111 self.request_restart().await;
112 return Ok(());
113 }
114
115 let pre = self.shard.stage();
116 let (event, action, successful) = self.recv_event().await?;
117 let post = self.shard.stage();
118
119 if post != pre {
120 self.update_manager().await;
121
122 for event_handler in self.event_handlers.clone() {
123 let context = self.make_context();
124 let event = ShardStageUpdateEvent {
125 new: post,
126 old: pre,
127 shard_id: self.shard.shard_info().id,
128 };
129 spawn_named("dispatch::event_handler::shard_stage_update", async move {
130 event_handler.shard_stage_update(context, event).await;
131 });
132 }
133 }
134
135 match action {
136 Some(ShardAction::Reconnect(ReconnectType::Reidentify)) => {
137 self.request_restart().await;
138 return Ok(());
139 },
140 Some(other) => {
141 if let Err(e) = self.action(&other).await {
142 debug!(
143 "[ShardRunner {:?}] Reconnecting due to error performing {:?}: {:?}",
144 self.shard.shard_info(),
145 other,
146 e
147 );
148 match self.shard.reconnection_type() {
149 ReconnectType::Reidentify => {
150 self.request_restart().await;
151 return Ok(());
152 },
153 ReconnectType::Resume => {
154 if let Err(why) = self.shard.resume().await {
155 warn!(
156 "[ShardRunner {:?}] Resume failed, reidentifying: {:?}",
157 self.shard.shard_info(),
158 why
159 );
160
161 self.request_restart().await;
162 return Ok(());
163 }
164 },
165 };
166 }
167 },
168 None => {},
169 }
170
171 if let Some(event) = event {
172 #[cfg(feature = "collector")]
173 self.collectors.lock().expect("poison").retain_mut(|callback| (callback.0)(&event));
174
175 dispatch_model(
176 event,
177 &self.make_context(),
178 #[cfg(feature = "framework")]
179 self.framework.clone(),
180 self.event_handlers.clone(),
181 self.raw_event_handlers.clone(),
182 );
183 }
184
185 if !successful && !self.shard.stage().is_connecting() {
186 self.request_restart().await;
187 return Ok(());
188 }
189 trace!("[ShardRunner {:?}] loop iteration reached the end.", self.shard.shard_info());
190 }
191 }
192
193 pub(super) fn runner_tx(&self) -> Sender<ShardRunnerMessage> {
195 self.runner_tx.clone()
196 }
197
198 #[instrument(skip(self, action))]
207 async fn action(&mut self, action: &ShardAction) -> Result<()> {
208 match *action {
209 ShardAction::Reconnect(ReconnectType::Reidentify) => {
210 self.request_restart().await;
211 Ok(())
212 },
213 ShardAction::Reconnect(ReconnectType::Resume) => self.shard.resume().await,
214 ShardAction::Heartbeat => self.shard.heartbeat().await,
215 ShardAction::Identify => self.shard.identify().await,
216 }
217 }
218
219 #[instrument(skip(self))]
226 async fn checked_shutdown(&mut self, id: ShardId, close_code: u16) -> bool {
227 if id != self.shard.shard_info().id {
229 return true;
231 }
232
233 drop(
235 self.shard
236 .client
237 .close(Some(CloseFrame {
238 code: close_code.into(),
239 reason: Cow::from(""),
240 }))
241 .await,
242 );
243
244 loop {
247 match self.shard.client.next().await {
248 Some(Ok(tungstenite::Message::Close(_))) => break,
249 Some(Err(_)) => {
250 warn!(
251 "[ShardRunner {:?}] Received an error awaiting close frame",
252 self.shard.shard_info(),
253 );
254 break;
255 },
256 _ => continue,
257 }
258 }
259
260 self.manager.shutdown_finished(id);
262 false
263 }
264
265 fn make_context(&self) -> Context {
266 Context::new(
267 Arc::clone(&self.data),
268 self,
269 self.shard.shard_info().id,
270 Arc::clone(&self.http),
271 #[cfg(feature = "cache")]
272 Arc::clone(&self.cache),
273 )
274 }
275
276 #[instrument(skip(self))]
283 async fn handle_rx_value(&mut self, msg: ShardRunnerMessage) -> bool {
284 match msg {
285 ShardRunnerMessage::Restart(id) => self.checked_shutdown(id, 4000).await,
286 ShardRunnerMessage::Shutdown(id, code) => self.checked_shutdown(id, code).await,
287 ShardRunnerMessage::ChunkGuild {
288 guild_id,
289 limit,
290 presences,
291 filter,
292 nonce,
293 } => self
294 .shard
295 .chunk_guild(guild_id, limit, presences, filter, nonce.as_deref())
296 .await
297 .is_ok(),
298 ShardRunnerMessage::Close(code, reason) => {
299 let reason = reason.unwrap_or_default();
300 let close = CloseFrame {
301 code: code.into(),
302 reason: Cow::from(reason),
303 };
304 self.shard.client.close(Some(close)).await.is_ok()
305 },
306 ShardRunnerMessage::Message(msg) => self.shard.client.send(msg).await.is_ok(),
307 ShardRunnerMessage::SetActivity(activity) => {
308 self.shard.set_activity(activity);
309 self.shard.update_presence().await.is_ok()
310 },
311 ShardRunnerMessage::SetPresence(activity, status) => {
312 self.shard.set_presence(activity, status);
313 self.shard.update_presence().await.is_ok()
314 },
315 ShardRunnerMessage::SetStatus(status) => {
316 self.shard.set_status(status);
317 self.shard.update_presence().await.is_ok()
318 },
319 }
320 }
321
322 #[cfg(feature = "voice")]
323 #[instrument(skip(self))]
324 async fn handle_voice_event(&self, event: &Event) {
325 if let Some(voice_manager) = &self.voice_manager {
326 match event {
327 Event::Ready(_) => {
328 voice_manager
329 .register_shard(self.shard.shard_info().id.0, self.runner_tx.clone())
330 .await;
331 },
332 Event::VoiceServerUpdate(event) => {
333 if let Some(guild_id) = event.guild_id {
334 voice_manager.server_update(guild_id, &event.endpoint, &event.token).await;
335 }
336 },
337 Event::VoiceStateUpdate(event) => {
338 if let Some(guild_id) = event.voice_state.guild_id {
339 voice_manager.state_update(guild_id, &event.voice_state).await;
340 }
341 },
342 _ => {},
343 }
344 }
345 }
346
347 #[instrument(skip(self))]
355 async fn recv(&mut self) -> bool {
356 loop {
357 match self.runner_rx.try_next() {
358 Ok(Some(value)) => {
359 if !self.handle_rx_value(value).await {
360 return false;
361 }
362 },
363 Ok(None) => {
364 warn!(
365 "[ShardRunner {:?}] Sending half DC; restarting",
366 self.shard.shard_info(),
367 );
368
369 self.request_restart().await;
370 return false;
371 },
372 Err(_) => break,
373 }
374 }
375
376 true
379 }
380
381 #[instrument(skip(self))]
384 async fn recv_event(&mut self) -> Result<(Option<Event>, Option<ShardAction>, bool)> {
385 let gw_event = match self.shard.client.recv_json().await {
386 Ok(inner) => Ok(inner),
387 Err(Error::Tungstenite(TungsteniteError::Io(_))) => {
388 debug!("Attempting to auto-reconnect");
389
390 match self.shard.reconnection_type() {
391 ReconnectType::Reidentify => return Ok((None, None, false)),
392 ReconnectType::Resume => {
393 if let Err(why) = self.shard.resume().await {
394 warn!("Failed to resume: {:?}", why);
395
396 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
398
399 return Ok((None, None, false));
400 }
401 },
402 }
403
404 return Ok((None, None, true));
405 },
406 Err(why) => Err(why),
407 };
408
409 let event = match gw_event {
410 Ok(Some(event)) => Ok(event),
411 Ok(None) => return Ok((None, None, true)),
412 Err(why) => Err(why),
413 };
414
415 let action = match self.shard.handle_event(&event) {
416 Ok(Some(action)) => Some(action),
417 Ok(None) => None,
418 Err(why) => {
419 error!("Shard handler received err: {:?}", why);
420
421 match &why {
422 Error::Gateway(
423 error @ (GatewayError::InvalidAuthentication
424 | GatewayError::InvalidGatewayIntents
425 | GatewayError::DisallowedGatewayIntents),
426 ) => {
427 self.manager.return_with_value(Err(error.clone())).await;
428
429 return Err(why);
430 },
431 _ => return Ok((None, None, true)),
432 }
433 },
434 };
435
436 if let Ok(GatewayEvent::HeartbeatAck) = event {
437 self.update_manager().await;
438 }
439
440 #[cfg(feature = "voice")]
441 {
442 if let Ok(GatewayEvent::Dispatch(_, ref event)) = event {
443 self.handle_voice_event(event).await;
444 }
445 }
446
447 let event = match event {
448 Ok(GatewayEvent::Dispatch(_, event)) => Some(event),
449 _ => None,
450 };
451
452 Ok((event, action, true))
453 }
454
455 #[instrument(skip(self))]
456 async fn request_restart(&mut self) {
457 debug!("[ShardRunner {:?}] Requesting restart", self.shard.shard_info());
458
459 self.update_manager().await;
460
461 let shard_id = self.shard.shard_info().id;
462 self.manager.restart_shard(shard_id).await;
463
464 #[cfg(feature = "voice")]
465 if let Some(voice_manager) = &self.voice_manager {
466 voice_manager.deregister_shard(shard_id.0).await;
467 }
468 }
469
470 #[instrument(skip(self))]
471 async fn update_manager(&self) {
472 self.manager
473 .update_shard_latency_and_stage(
474 self.shard.shard_info().id,
475 self.shard.latency(),
476 self.shard.stage(),
477 )
478 .await;
479 }
480}
481
/// Everything needed to construct a [`ShardRunner`] via `ShardRunner::new`.
///
/// Each field is moved into the runner field of the same name.
pub struct ShardRunnerOptions {
    /// Shared user data handed to each `Context` the runner creates.
    pub data: Arc<RwLock<TypeMap>>,
    /// Handlers that receive dispatched events and shard stage updates.
    pub event_handlers: Vec<Arc<dyn EventHandler>>,
    /// Handlers that receive raw dispatched events.
    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    /// Optional framework that dispatched events are also passed to.
    #[cfg(feature = "framework")]
    pub framework: Option<Arc<dyn Framework>>,
    /// Manager the runner reports to.
    pub manager: Arc<ShardManager>,
    /// The shard the runner will drive.
    pub shard: Shard,
    /// Optional voice manager to forward voice-related events to.
    #[cfg(feature = "voice")]
    pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
    /// Cache handed to each `Context` the runner creates.
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    /// HTTP client handed to each `Context` the runner creates.
    pub http: Arc<Http>,
}