1use std::borrow::Cow;
2use std::sync::Arc;
3
4use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
5use tokio::sync::RwLock;
6use tokio_tungstenite::tungstenite;
7use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
8use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame;
9use tracing::{debug, error, info, instrument, trace, warn};
10use typemap_rev::TypeMap;
11
12use super::event::ShardStageUpdateEvent;
13#[cfg(feature = "collector")]
14use super::CollectorCallback;
15#[cfg(feature = "voice")]
16use super::VoiceGatewayManager;
17use super::{ShardId, ShardManager, ShardRunnerMessage};
18#[cfg(feature = "cache")]
19use crate::cache::Cache;
20use crate::client::dispatch::dispatch_model;
21use crate::client::{Context, EventHandler, RawEventHandler};
22#[cfg(feature = "framework")]
23use crate::framework::Framework;
24use crate::gateway::{GatewayError, ReconnectType, Shard, ShardAction};
25use crate::http::Http;
26use crate::internal::prelude::*;
27use crate::internal::tokio::spawn_named;
28use crate::model::event::{Event, GatewayEvent};
29
/// Drives a single [`Shard`]: it receives gateway events, dispatches them to the configured
/// handlers, and processes [`ShardRunnerMessage`]s sent to it through its channel.
pub struct ShardRunner {
    /// Shared user data, handed to every [`Context`] this runner creates.
    data: Arc<RwLock<TypeMap>>,
    /// Handlers that receive dispatched events and shard stage updates.
    event_handlers: Vec<Arc<dyn EventHandler>>,
    /// Raw event handlers, passed along to `dispatch_model` for every dispatched event.
    raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    /// Optional framework that dispatched events are also passed to.
    #[cfg(feature = "framework")]
    framework: Option<Arc<dyn Framework>>,
    /// Manager kept informed of this shard's latency and stage, of restart requests, and of
    /// completed shutdowns.
    manager: Arc<ShardManager>,
    /// Receiving half of the runner's message channel; drained at the start of every loop
    /// iteration of `run`.
    runner_rx: Receiver<ShardRunnerMessage>,
    /// Sending half of the runner's message channel, cloned out to anything that needs to send
    /// messages to this runner.
    runner_tx: Sender<ShardRunnerMessage>,
    /// The shard driven by this runner.
    pub(crate) shard: Shard,
    /// Optional voice manager: it is registered with on `Ready`, fed voice server/state updates,
    /// and deregistered when a restart is requested.
    #[cfg(feature = "voice")]
    voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
    /// Cache handed to every [`Context`] this runner creates.
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    /// HTTP client handed to every [`Context`] this runner creates.
    pub http: Arc<Http>,
    /// Callbacks run against every dispatched event; a callback that returns `false` is removed.
    #[cfg(feature = "collector")]
    pub(crate) collectors: Arc<std::sync::Mutex<Vec<CollectorCallback>>>,
}
51
52impl ShardRunner {
53 pub fn new(opt: ShardRunnerOptions) -> Self {
55 let (tx, rx) = mpsc::unbounded();
56
57 Self {
58 runner_rx: rx,
59 runner_tx: tx,
60 data: opt.data,
61 event_handlers: opt.event_handlers,
62 raw_event_handlers: opt.raw_event_handlers,
63 #[cfg(feature = "framework")]
64 framework: opt.framework,
65 manager: opt.manager,
66 shard: opt.shard,
67 #[cfg(feature = "voice")]
68 voice_manager: opt.voice_manager,
69 #[cfg(feature = "cache")]
70 cache: opt.cache,
71 http: opt.http,
72 #[cfg(feature = "collector")]
73 collectors: Arc::new(std::sync::Mutex::new(vec![])),
74 }
75 }
76
77 #[instrument(skip(self))]
102 pub async fn run(&mut self) -> Result<()> {
103 info!("[ShardRunner {:?}] Running", self.shard.shard_info());
104
105 loop {
106 trace!("[ShardRunner {:?}] loop iteration started.", self.shard.shard_info());
107 if !self.recv().await {
108 return Ok(());
109 }
110
111 if !self.shard.do_heartbeat().await {
113 warn!("[ShardRunner {:?}] Error heartbeating", self.shard.shard_info(),);
114
115 self.request_restart().await;
116 return Ok(());
117 }
118
119 let pre = self.shard.stage();
120 let (event, action, successful) = self.recv_event().await?;
121 let post = self.shard.stage();
122
123 if post != pre {
124 self.update_manager().await;
125
126 for event_handler in self.event_handlers.clone() {
127 let context = self.make_context();
128 let event = ShardStageUpdateEvent {
129 new: post,
130 old: pre,
131 shard_id: self.shard.shard_info().id,
132 };
133 spawn_named("dispatch::event_handler::shard_stage_update", async move {
134 event_handler.shard_stage_update(context, event).await;
135 });
136 }
137 }
138
139 match action {
140 Some(ShardAction::Reconnect(ReconnectType::Reidentify)) => {
141 self.request_restart().await;
142 return Ok(());
143 },
144 Some(other) => {
145 if let Err(e) = self.action(&other).await {
146 debug!(
147 "[ShardRunner {:?}] Reconnecting due to error performing {:?}: {:?}",
148 self.shard.shard_info(),
149 other,
150 e
151 );
152 match self.shard.reconnection_type() {
153 ReconnectType::Reidentify => {
154 self.request_restart().await;
155 return Ok(());
156 },
157 ReconnectType::Resume => {
158 if let Err(why) = self.shard.resume().await {
159 warn!(
160 "[ShardRunner {:?}] Resume failed, reidentifying: {:?}",
161 self.shard.shard_info(),
162 why
163 );
164
165 self.request_restart().await;
166 return Ok(());
167 }
168 },
169 }
170 }
171 },
172 None => {},
173 }
174
175 if let Some(event) = event {
176 #[cfg(feature = "collector")]
177 self.collectors.lock().expect("poison").retain_mut(|callback| (callback.0)(&event));
178
179 dispatch_model(
180 event,
181 &self.make_context(),
182 #[cfg(feature = "framework")]
183 self.framework.clone(),
184 self.event_handlers.clone(),
185 self.raw_event_handlers.clone(),
186 );
187 }
188
189 if !successful && !self.shard.stage().is_connecting() {
190 self.request_restart().await;
191 return Ok(());
192 }
193 trace!("[ShardRunner {:?}] loop iteration reached the end.", self.shard.shard_info());
194 }
195 }
196
197 pub(super) fn runner_tx(&self) -> Sender<ShardRunnerMessage> {
199 self.runner_tx.clone()
200 }
201
202 #[instrument(skip(self, action))]
211 async fn action(&mut self, action: &ShardAction) -> Result<()> {
212 match *action {
213 ShardAction::Reconnect(ReconnectType::Reidentify) => {
214 self.request_restart().await;
215 Ok(())
216 },
217 ShardAction::Reconnect(ReconnectType::Resume) => self.shard.resume().await,
218 ShardAction::Heartbeat => self.shard.heartbeat().await,
219 ShardAction::Identify => self.shard.identify().await,
220 }
221 }
222
223 #[instrument(skip(self))]
230 async fn checked_shutdown(&mut self, id: ShardId, close_code: u16) -> bool {
231 if id != self.shard.shard_info().id {
233 return true;
235 }
236
237 drop(
239 self.shard
240 .client
241 .close(Some(CloseFrame {
242 code: close_code.into(),
243 reason: Cow::from(""),
244 }))
245 .await,
246 );
247
248 loop {
251 match self.shard.client.next().await {
252 Some(Ok(tungstenite::Message::Close(_))) => break,
253 Some(Err(_)) => {
254 warn!(
255 "[ShardRunner {:?}] Received an error awaiting close frame",
256 self.shard.shard_info(),
257 );
258 break;
259 },
260 _ => {},
261 }
262 }
263
264 self.manager.shutdown_finished(id);
266 false
267 }
268
269 fn make_context(&self) -> Context {
270 Context::new(
271 Arc::clone(&self.data),
272 self,
273 self.shard.shard_info().id,
274 Arc::clone(&self.http),
275 #[cfg(feature = "cache")]
276 Arc::clone(&self.cache),
277 )
278 }
279
280 #[instrument(skip(self))]
287 async fn handle_rx_value(&mut self, msg: ShardRunnerMessage) -> bool {
288 match msg {
289 ShardRunnerMessage::Restart(id) => self.checked_shutdown(id, 4000).await,
290 ShardRunnerMessage::Shutdown(id, code) => self.checked_shutdown(id, code).await,
291 ShardRunnerMessage::ChunkGuild {
292 guild_id,
293 limit,
294 presences,
295 filter,
296 nonce,
297 } => self
298 .shard
299 .chunk_guild(guild_id, limit, presences, filter, nonce.as_deref())
300 .await
301 .is_ok(),
302 ShardRunnerMessage::SoundboardSounds {
303 guild_ids,
304 } => self.shard.request_soundboard_sounds(&guild_ids).await.is_ok(),
305 ShardRunnerMessage::Close(code, reason) => {
306 let reason = reason.unwrap_or_default();
307 let close = CloseFrame {
308 code: code.into(),
309 reason: Cow::from(reason),
310 };
311 self.shard.client.close(Some(close)).await.is_ok()
312 },
313 ShardRunnerMessage::Message(msg) => self.shard.client.send(msg).await.is_ok(),
314 ShardRunnerMessage::SetActivity(activity) => {
315 self.shard.set_activity(activity);
316 self.shard.update_presence().await.is_ok()
317 },
318 ShardRunnerMessage::SetPresence(activity, status) => {
319 self.shard.set_presence(activity, status);
320 self.shard.update_presence().await.is_ok()
321 },
322 ShardRunnerMessage::SetStatus(status) => {
323 self.shard.set_status(status);
324 self.shard.update_presence().await.is_ok()
325 },
326 }
327 }
328
329 #[cfg(feature = "voice")]
330 #[instrument(skip(self))]
331 async fn handle_voice_event(&self, event: &Event) {
332 if let Some(voice_manager) = &self.voice_manager {
333 match event {
334 Event::Ready(_) => {
335 voice_manager
336 .register_shard(self.shard.shard_info().id.0, self.runner_tx.clone())
337 .await;
338 },
339 Event::VoiceServerUpdate(event) => {
340 if let Some(guild_id) = event.guild_id {
341 voice_manager.server_update(guild_id, &event.endpoint, &event.token).await;
342 }
343 },
344 Event::VoiceStateUpdate(event) => {
345 if let Some(guild_id) = event.voice_state.guild_id {
346 voice_manager.state_update(guild_id, &event.voice_state).await;
347 }
348 },
349 _ => {},
350 }
351 }
352 }
353
354 #[instrument(skip(self))]
362 async fn recv(&mut self) -> bool {
363 loop {
364 match self.runner_rx.try_next() {
365 Ok(Some(value)) => {
366 if !self.handle_rx_value(value).await {
367 return false;
368 }
369 },
370 Ok(None) => {
371 warn!(
372 "[ShardRunner {:?}] Sending half DC; restarting",
373 self.shard.shard_info(),
374 );
375
376 self.request_restart().await;
377 return false;
378 },
379 Err(_) => break,
380 }
381 }
382
383 true
386 }
387
388 #[instrument(skip(self))]
391 async fn recv_event(&mut self) -> Result<(Option<Event>, Option<ShardAction>, bool)> {
392 let gw_event = match self.shard.client.recv_json().await {
393 Ok(inner) => Ok(inner),
394 Err(Error::Tungstenite(TungsteniteError::Io(_))) => {
395 debug!("Attempting to auto-reconnect");
396
397 match self.shard.reconnection_type() {
398 ReconnectType::Reidentify => return Ok((None, None, false)),
399 ReconnectType::Resume => {
400 if let Err(why) = self.shard.resume().await {
401 warn!("Failed to resume: {:?}", why);
402
403 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
405
406 return Ok((None, None, false));
407 }
408 },
409 }
410
411 return Ok((None, None, true));
412 },
413 Err(why) => Err(why),
414 };
415
416 let event = match gw_event {
417 Ok(Some(event)) => Ok(event),
418 Ok(None) => return Ok((None, None, true)),
419 Err(why) => Err(why),
420 };
421
422 let action = match self.shard.handle_event(&event) {
423 Ok(Some(action)) => Some(action),
424 Ok(None) => None,
425 Err(why) => {
426 error!("Shard handler received err: {:?}", why);
427
428 match &why {
429 Error::Gateway(
430 error @ (GatewayError::InvalidAuthentication
431 | GatewayError::InvalidGatewayIntents
432 | GatewayError::DisallowedGatewayIntents),
433 ) => {
434 self.manager.return_with_value(Err(error.clone())).await;
435
436 return Err(why);
437 },
438 _ => return Ok((None, None, true)),
439 }
440 },
441 };
442
443 if let Ok(GatewayEvent::HeartbeatAck) = event {
444 self.update_manager().await;
445 }
446
447 #[cfg(feature = "voice")]
448 {
449 if let Ok(GatewayEvent::Dispatch(_, ref event)) = event {
450 self.handle_voice_event(event).await;
451 }
452 }
453
454 let event = match event {
455 Ok(GatewayEvent::Dispatch(_, event)) => Some(event),
456 _ => None,
457 };
458
459 Ok((event, action, true))
460 }
461
462 #[instrument(skip(self))]
463 async fn request_restart(&mut self) {
464 debug!("[ShardRunner {:?}] Requesting restart", self.shard.shard_info());
465
466 self.update_manager().await;
467
468 let shard_id = self.shard.shard_info().id;
469 self.manager.restart_shard(shard_id).await;
470
471 #[cfg(feature = "voice")]
472 if let Some(voice_manager) = &self.voice_manager {
473 voice_manager.deregister_shard(shard_id.0).await;
474 }
475 }
476
477 #[instrument(skip(self))]
478 async fn update_manager(&self) {
479 self.manager
480 .update_shard_latency_and_stage(
481 self.shard.shard_info().id,
482 self.shard.latency(),
483 self.shard.stage(),
484 )
485 .await;
486 }
487}
488
/// Everything needed to construct a [`ShardRunner`] via [`ShardRunner::new`].
pub struct ShardRunnerOptions {
    /// Shared user data made available through every [`Context`] the runner creates.
    pub data: Arc<RwLock<TypeMap>>,
    /// Handlers that receive dispatched events and shard stage updates.
    pub event_handlers: Vec<Arc<dyn EventHandler>>,
    /// Raw event handlers that dispatched events are passed to.
    pub raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
    /// Optional framework that dispatched events are also passed to.
    #[cfg(feature = "framework")]
    pub framework: Option<Arc<dyn Framework>>,
    /// The shard manager the runner reports latency, stage, restarts and shutdowns to.
    pub manager: Arc<ShardManager>,
    /// The shard the runner will drive.
    pub shard: Shard,
    /// Optional voice manager to notify about voice-related events.
    #[cfg(feature = "voice")]
    pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
    /// Cache shared with every [`Context`] the runner creates.
    #[cfg(feature = "cache")]
    pub cache: Arc<Cache>,
    /// HTTP client shared with every [`Context`] the runner creates.
    pub http: Arc<Http>,
}