1#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
2use crate::loom::cell::UnsafeCell;
19use crate::loom::sync::atomic::AtomicUsize;
20use crate::loom::sync::{Mutex, MutexGuard};
21use crate::util::linked_list::{self, LinkedList};
22#[cfg(all(tokio_unstable, feature = "tracing"))]
23use crate::util::trace;
24use crate::util::WakeList;
25
26use std::future::Future;
27use std::marker::PhantomPinned;
28use std::pin::Pin;
29use std::ptr::NonNull;
30use std::sync::atomic::Ordering::*;
31use std::task::{ready, Context, Poll, Waker};
32use std::{cmp, fmt};
33
/// An asynchronous counting semaphore whose acquirers may request several
/// permits at once.
pub(crate) struct Semaphore {
    /// Wait queue of pending acquirers plus the `closed` flag, guarded by a mutex.
    waiters: Mutex<Waitlist>,
    /// Packed state: the number of available permits shifted left by
    /// `Semaphore::PERMIT_SHIFT`, with the low bit holding the
    /// `Semaphore::CLOSED` flag.
    permits: AtomicUsize,
    /// Tracing span for this semaphore resource (tokio_unstable + tracing only).
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    resource_span: tracing::Span,
}
42
/// State protected by the semaphore's mutex.
struct Waitlist {
    /// Intrusive queue of waiters. New waiters are pushed at the front and
    /// permits are assigned starting from the back, so the oldest waiter is
    /// served first.
    queue: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
    /// Set by `Semaphore::close`; waiters that observe it fail with `AcquireError`.
    closed: bool,
}
47
/// Error returned by `Semaphore::try_acquire`.
#[derive(Debug, PartialEq, Eq)]
pub enum TryAcquireError {
    /// The semaphore has been closed and cannot issue new permits.
    Closed,

    /// Fewer permits are currently available than were requested.
    NoPermits,
}
/// Error produced by an `Acquire` future when the semaphore has been closed.
///
/// The private unit field prevents construction outside this module.
#[derive(Debug)]
pub struct AcquireError(());
70
/// Future returned by `Semaphore::acquire`; resolves once `num_permits`
/// permits have been acquired or the semaphore is closed.
pub(crate) struct Acquire<'a> {
    /// Intrusive wait-queue node embedded in the future itself.
    node: Waiter,
    /// The semaphore permits are being acquired from.
    semaphore: &'a Semaphore,
    /// Total number of permits requested.
    num_permits: usize,
    /// Set once `node` has been pushed onto the wait queue and cleared when the
    /// acquire completes successfully. `Drop` uses it to decide whether the
    /// node must be unlinked.
    queued: bool,
}
77
/// An entry in the semaphore's wait queue, embedded in each `Acquire` future.
struct Waiter {
    /// Number of permits this waiter still needs; decremented by
    /// `Waiter::assign_permits`.
    state: AtomicUsize,

    /// Waker to notify when this waiter has been fully served or the semaphore
    /// is closed. It is read and written only while the wait-list lock is held.
    waker: UnsafeCell<Option<Waker>>,

    /// Intrusive linked-list pointers used by `Waitlist::queue`.
    pointers: linked_list::Pointers<Waiter>,

    /// Tracing context for this acquire operation (tokio_unstable + tracing only).
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    ctx: trace::AsyncOpTracingCtx,

    /// Makes `Waiter` `!Unpin`. Its address is stored in the intrusive queue,
    /// so it must not move while linked.
    _p: PhantomPinned,
}
113
// Generates `Waiter::addr_of_pointers`, which maps a `NonNull<Waiter>` to a
// pointer to its `pointers` field. It is used by the `linked_list::Link` impl
// for `Waiter`.
generate_addr_of_methods! {
    impl<> Waiter {
        unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> {
            &self.pointers
        }
    }
}
121
122impl Semaphore {
123    pub(crate) const MAX_PERMITS: usize = usize::MAX >> 3;
131    const CLOSED: usize = 1;
132    const PERMIT_SHIFT: usize = 1;
136
137    pub(crate) fn new(permits: usize) -> Self {
141        assert!(
142            permits <= Self::MAX_PERMITS,
143            "a semaphore may not have more than MAX_PERMITS permits ({})",
144            Self::MAX_PERMITS
145        );
146
147        #[cfg(all(tokio_unstable, feature = "tracing"))]
148        let resource_span = {
149            let resource_span = tracing::trace_span!(
150                parent: None,
151                "runtime.resource",
152                concrete_type = "Semaphore",
153                kind = "Sync",
154                is_internal = true
155            );
156
157            resource_span.in_scope(|| {
158                tracing::trace!(
159                    target: "runtime::resource::state_update",
160                    permits = permits,
161                    permits.op = "override",
162                )
163            });
164            resource_span
165        };
166
167        Self {
168            permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
169            waiters: Mutex::new(Waitlist {
170                queue: LinkedList::new(),
171                closed: false,
172            }),
173            #[cfg(all(tokio_unstable, feature = "tracing"))]
174            resource_span,
175        }
176    }
177
178    #[cfg(not(all(loom, test)))]
182    pub(crate) const fn const_new(permits: usize) -> Self {
183        assert!(permits <= Self::MAX_PERMITS);
184
185        Self {
186            permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
187            waiters: Mutex::const_new(Waitlist {
188                queue: LinkedList::new(),
189                closed: false,
190            }),
191            #[cfg(all(tokio_unstable, feature = "tracing"))]
192            resource_span: tracing::Span::none(),
193        }
194    }
195
196    pub(crate) fn new_closed() -> Self {
198        Self {
199            permits: AtomicUsize::new(Self::CLOSED),
200            waiters: Mutex::new(Waitlist {
201                queue: LinkedList::new(),
202                closed: true,
203            }),
204            #[cfg(all(tokio_unstable, feature = "tracing"))]
205            resource_span: tracing::Span::none(),
206        }
207    }
208
209    #[cfg(not(all(loom, test)))]
211    pub(crate) const fn const_new_closed() -> Self {
212        Self {
213            permits: AtomicUsize::new(Self::CLOSED),
214            waiters: Mutex::const_new(Waitlist {
215                queue: LinkedList::new(),
216                closed: true,
217            }),
218            #[cfg(all(tokio_unstable, feature = "tracing"))]
219            resource_span: tracing::Span::none(),
220        }
221    }
222
223    pub(crate) fn available_permits(&self) -> usize {
225        self.permits.load(Acquire) >> Self::PERMIT_SHIFT
226    }
227
228    pub(crate) fn release(&self, added: usize) {
232        if added == 0 {
233            return;
234        }
235
236        self.add_permits_locked(added, self.waiters.lock());
238    }
239
240    pub(crate) fn close(&self) {
243        let mut waiters = self.waiters.lock();
244        self.permits.fetch_or(Self::CLOSED, Release);
252        waiters.closed = true;
253        while let Some(mut waiter) = waiters.queue.pop_back() {
254            let waker = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) };
255            if let Some(waker) = waker {
256                waker.wake();
257            }
258        }
259    }
260
261    pub(crate) fn is_closed(&self) -> bool {
263        self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED
264    }
265
266    pub(crate) fn try_acquire(&self, num_permits: usize) -> Result<(), TryAcquireError> {
267        assert!(
268            num_permits <= Self::MAX_PERMITS,
269            "a semaphore may not have more than MAX_PERMITS permits ({})",
270            Self::MAX_PERMITS
271        );
272        let num_permits = num_permits << Self::PERMIT_SHIFT;
273        let mut curr = self.permits.load(Acquire);
274        loop {
275            if curr & Self::CLOSED == Self::CLOSED {
277                return Err(TryAcquireError::Closed);
278            }
279
280            if curr < num_permits {
282                return Err(TryAcquireError::NoPermits);
283            }
284
285            let next = curr - num_permits;
286
287            match self.permits.compare_exchange(curr, next, AcqRel, Acquire) {
288                Ok(_) => {
289                    return Ok(());
291                }
292                Err(actual) => curr = actual,
293            }
294        }
295    }
296
297    pub(crate) fn acquire(&self, num_permits: usize) -> Acquire<'_> {
298        Acquire::new(self, num_permits)
299    }
300
301    fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) {
307        let mut wakers = WakeList::new();
308        let mut lock = Some(waiters);
309        let mut is_empty = false;
310        while rem > 0 {
311            let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock());
312            'inner: while wakers.can_push() {
313                match waiters.queue.last() {
315                    Some(waiter) => {
316                        if !waiter.assign_permits(&mut rem) {
317                            break 'inner;
318                        }
319                    }
320                    None => {
321                        is_empty = true;
322                        break 'inner;
325                    }
326                };
327                let mut waiter = waiters.queue.pop_back().unwrap();
328                if let Some(waker) =
329                    unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }
330                {
331                    wakers.push(waker);
332                }
333            }
334
335            if rem > 0 && is_empty {
336                let permits = rem;
337                assert!(
338                    permits <= Self::MAX_PERMITS,
339                    "cannot add more than MAX_PERMITS permits ({})",
340                    Self::MAX_PERMITS
341                );
342                let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release);
343                let prev = prev >> Self::PERMIT_SHIFT;
344                assert!(
345                    prev + permits <= Self::MAX_PERMITS,
346                    "number of added permits ({}) would overflow MAX_PERMITS ({})",
347                    rem,
348                    Self::MAX_PERMITS
349                );
350
351                #[cfg(all(tokio_unstable, feature = "tracing"))]
353                self.resource_span.in_scope(|| {
354                    tracing::trace!(
355                    target: "runtime::resource::state_update",
356                    permits = rem,
357                    permits.op = "add",
358                    )
359                });
360
361                rem = 0;
362            }
363
364            drop(waiters); wakers.wake_all();
367        }
368
369        assert_eq!(rem, 0);
370    }
371
372    pub(crate) fn forget_permits(&self, n: usize) -> usize {
377        if n == 0 {
378            return 0;
379        }
380
381        let mut curr_bits = self.permits.load(Acquire);
382        loop {
383            let curr = curr_bits >> Self::PERMIT_SHIFT;
384            let new = curr.saturating_sub(n);
385            match self.permits.compare_exchange_weak(
386                curr_bits,
387                new << Self::PERMIT_SHIFT,
388                AcqRel,
389                Acquire,
390            ) {
391                Ok(_) => return std::cmp::min(curr, n),
392                Err(actual) => curr_bits = actual,
393            };
394        }
395    }
396
397    fn poll_acquire(
398        &self,
399        cx: &mut Context<'_>,
400        num_permits: usize,
401        node: Pin<&mut Waiter>,
402        queued: bool,
403    ) -> Poll<Result<(), AcquireError>> {
404        let mut acquired = 0;
405
406        let needed = if queued {
407            node.state.load(Acquire) << Self::PERMIT_SHIFT
408        } else {
409            num_permits << Self::PERMIT_SHIFT
410        };
411
412        let mut lock = None;
413        let mut curr = self.permits.load(Acquire);
416        let mut waiters = loop {
417            if curr & Self::CLOSED > 0 {
419                return Poll::Ready(Err(AcquireError::closed()));
420            }
421
422            let mut remaining = 0;
423            let total = curr
424                .checked_add(acquired)
425                .expect("number of permits must not overflow");
426            let (next, acq) = if total >= needed {
427                let next = curr - (needed - acquired);
428                (next, needed >> Self::PERMIT_SHIFT)
429            } else {
430                remaining = (needed - acquired) - curr;
431                (0, curr >> Self::PERMIT_SHIFT)
432            };
433
434            if remaining > 0 && lock.is_none() {
435                lock = Some(self.waiters.lock());
443            }
444
445            match self.permits.compare_exchange(curr, next, AcqRel, Acquire) {
446                Ok(_) => {
447                    acquired += acq;
448                    if remaining == 0 {
449                        if !queued {
450                            #[cfg(all(tokio_unstable, feature = "tracing"))]
451                            self.resource_span.in_scope(|| {
452                                tracing::trace!(
453                                    target: "runtime::resource::state_update",
454                                    permits = acquired,
455                                    permits.op = "sub",
456                                );
457                                tracing::trace!(
458                                    target: "runtime::resource::async_op::state_update",
459                                    permits_obtained = acquired,
460                                    permits.op = "add",
461                                )
462                            });
463
464                            return Poll::Ready(Ok(()));
465                        } else if lock.is_none() {
466                            break self.waiters.lock();
467                        }
468                    }
469                    break lock.expect("lock must be acquired before waiting");
470                }
471                Err(actual) => curr = actual,
472            }
473        };
474
475        if waiters.closed {
476            return Poll::Ready(Err(AcquireError::closed()));
477        }
478
479        #[cfg(all(tokio_unstable, feature = "tracing"))]
480        self.resource_span.in_scope(|| {
481            tracing::trace!(
482                target: "runtime::resource::state_update",
483                permits = acquired,
484                permits.op = "sub",
485            )
486        });
487
488        if node.assign_permits(&mut acquired) {
489            self.add_permits_locked(acquired, waiters);
490            return Poll::Ready(Ok(()));
491        }
492
493        assert_eq!(acquired, 0);
494        let mut old_waker = None;
495
496        node.waker.with_mut(|waker| {
498            let waker = unsafe { &mut *waker };
500            if waker
502                .as_ref()
503                .map_or(true, |waker| !waker.will_wake(cx.waker()))
504            {
505                old_waker = std::mem::replace(waker, Some(cx.waker().clone()));
506            }
507        });
508
509        if !queued {
511            let node = unsafe {
512                let node = Pin::into_inner_unchecked(node) as *mut _;
513                NonNull::new_unchecked(node)
514            };
515
516            waiters.queue.push_front(node);
517        }
518        drop(waiters);
519        drop(old_waker);
520
521        Poll::Pending
522    }
523}
524
525impl fmt::Debug for Semaphore {
526    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
527        fmt.debug_struct("Semaphore")
528            .field("permits", &self.available_permits())
529            .finish()
530    }
531}
532
533impl Waiter {
534    fn new(
535        num_permits: usize,
536        #[cfg(all(tokio_unstable, feature = "tracing"))] ctx: trace::AsyncOpTracingCtx,
537    ) -> Self {
538        Waiter {
539            waker: UnsafeCell::new(None),
540            state: AtomicUsize::new(num_permits),
541            pointers: linked_list::Pointers::new(),
542            #[cfg(all(tokio_unstable, feature = "tracing"))]
543            ctx,
544            _p: PhantomPinned,
545        }
546    }
547
548    fn assign_permits(&self, n: &mut usize) -> bool {
552        let mut curr = self.state.load(Acquire);
553        loop {
554            let assign = cmp::min(curr, *n);
555            let next = curr - assign;
556            match self.state.compare_exchange(curr, next, AcqRel, Acquire) {
557                Ok(_) => {
558                    *n -= assign;
559                    #[cfg(all(tokio_unstable, feature = "tracing"))]
560                    self.ctx.async_op_span.in_scope(|| {
561                        tracing::trace!(
562                            target: "runtime::resource::async_op::state_update",
563                            permits_obtained = assign,
564                            permits.op = "add",
565                        );
566                    });
567                    return next == 0;
568                }
569                Err(actual) => curr = actual,
570            }
571        }
572    }
573}
574
impl Future for Acquire<'_> {
    type Output = Result<(), AcquireError>;

    /// Polls for the requested permits.
    ///
    /// Resolves to `Ok(())` once all `num_permits` permits have been acquired,
    /// or to `Err(AcquireError)` if the semaphore is closed. While pending, the
    /// waiter node is linked into the semaphore's wait queue and `queued` is
    /// set so that `Drop` knows to unlink it.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `ready!` returns early if the trace hook yields `Pending`.
        ready!(crate::trace::trace_leaf(cx));

        // Keep the tracing spans entered for the duration of this poll.
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let _resource_span = self.node.ctx.resource_span.clone().entered();
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let _async_op_span = self.node.ctx.async_op_span.clone().entered();
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let _async_op_poll_span = self.node.ctx.async_op_poll_span.clone().entered();

        let (node, semaphore, needed, queued) = self.project();

        // Ask the runtime's `coop` module whether this task may proceed;
        // `ready!` returns `Pending` early if it may not.
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let coop = ready!(trace_poll_op!(
            "poll_acquire",
            crate::runtime::coop::poll_proceed(cx),
        ));

        #[cfg(not(all(tokio_unstable, feature = "tracing")))]
        let coop = ready!(crate::runtime::coop::poll_proceed(cx));

        let result = match semaphore.poll_acquire(cx, needed, node, *queued) {
            Poll::Pending => {
                // `poll_acquire` has linked the node into the wait queue.
                *queued = true;
                Poll::Pending
            }
            Poll::Ready(r) => {
                coop.made_progress();
                // On error (closed) return early, leaving `queued` as is so
                // `Drop` still locks the wait list and unlinks the node if needed.
                r?;
                *queued = false;
                Poll::Ready(Ok(()))
            }
        };

        #[cfg(all(tokio_unstable, feature = "tracing"))]
        return trace_poll_op!("poll_acquire", result);

        #[cfg(not(all(tokio_unstable, feature = "tracing")))]
        return result;
    }
}
620
impl<'a> Acquire<'a> {
    /// Creates a future that will try to acquire `num_permits` permits from
    /// `semaphore`. No permits are taken until the future is first polled.
    fn new(semaphore: &'a Semaphore, num_permits: usize) -> Self {
        // Without tracing: just build the future around a fresh waiter node.
        #[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
        return Self {
            node: Waiter::new(num_permits),
            semaphore,
            num_permits,
            queued: false,
        };

        // With tracing: create the async-op spans inside the semaphore's
        // resource span and record the request before building the future.
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        return semaphore.resource_span.in_scope(|| {
            let async_op_span =
                tracing::trace_span!("runtime.resource.async_op", source = "Acquire::new");
            let async_op_poll_span = async_op_span.in_scope(|| {
                tracing::trace!(
                    target: "runtime::resource::async_op::state_update",
                    permits_requested = num_permits,
                    permits.op = "override",
                );

                tracing::trace!(
                    target: "runtime::resource::async_op::state_update",
                    permits_obtained = 0usize,
                    permits.op = "override",
                );

                tracing::trace_span!("runtime.resource.async_op.poll")
            });

            let ctx = trace::AsyncOpTracingCtx {
                async_op_span,
                async_op_poll_span,
                resource_span: semaphore.resource_span.clone(),
            };

            Self {
                node: Waiter::new(num_permits, ctx),
                semaphore,
                num_permits,
                queued: false,
            }
        });
    }

    /// Projects a pinned `Acquire` into its fields, keeping only `node` pinned.
    ///
    /// Returns `(node, semaphore, num_permits, &mut queued)`.
    fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, usize, &mut bool) {
        // Compile-time check helper: only type-checks for `Unpin` types.
        fn is_unpin<T: Unpin>() {}
        // SAFETY: every field other than `node` is `Unpin` (asserted by the
        // `is_unpin` calls below), so exposing them unpinned is fine. `node`
        // is only handed out as `Pin<&mut Waiter>` and is never moved.
        unsafe {
            is_unpin::<&Semaphore>();
            is_unpin::<&mut bool>();
            is_unpin::<usize>();

            let this = self.get_unchecked_mut();
            (
                Pin::new_unchecked(&mut this.node),
                this.semaphore,
                this.num_permits,
                &mut this.queued,
            )
        }
    }
}
685
impl Drop for Acquire<'_> {
    /// Unlinks the waiter node from the wait queue (if it was ever queued) and
    /// returns any permits that had already been assigned to it.
    fn drop(&mut self) {
        // A future that was never queued, or that completed successfully, has
        // no node in the wait list, so the lock can be skipped.
        if !self.queued {
            return;
        }

        // The node lives inside this future and is about to be freed, so it
        // must be removed from the intrusive list before `drop` returns.
        let mut waiters = self.semaphore.waiters.lock();

        let node = NonNull::from(&mut self.node);
        // SAFETY: the wait-list lock is held. The node may already have been
        // popped (e.g. by `close` or `add_permits_locked`).
        // NOTE(review): this relies on `LinkedList::remove` tolerating a node
        // that is no longer linked — confirm in `util::linked_list`.
        unsafe { waiters.queue.remove(node) };

        // Permits already credited to this waiter = requested - still needed.
        let acquired_permits = self.num_permits - self.node.state.load(Acquire);
        if acquired_permits > 0 {
            self.semaphore.add_permits_locked(acquired_permits, waiters);
        }
    }
}
710
// SAFETY: `Acquire` is not automatically `Sync` because its `Waiter` holds an
// `UnsafeCell`. That cell is touched only through `&mut` access to the future
// (`poll` via `project`, and `drop`) or through the wait queue while its lock
// is held. Nothing in this file reads it through `&Acquire`, so sharing
// `&Acquire` across threads is sound (if not particularly useful).
unsafe impl Sync for Acquire<'_> {}
717
718impl AcquireError {
721    fn closed() -> AcquireError {
722        AcquireError(())
723    }
724}
725
726impl fmt::Display for AcquireError {
727    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
728        write!(fmt, "semaphore closed")
729    }
730}
731
732impl std::error::Error for AcquireError {}
733
734impl TryAcquireError {
737    #[allow(dead_code)] pub(crate) fn is_closed(&self) -> bool {
740        matches!(self, TryAcquireError::Closed)
741    }
742
743    #[allow(dead_code)] pub(crate) fn is_no_permits(&self) -> bool {
747        matches!(self, TryAcquireError::NoPermits)
748    }
749}
750
751impl fmt::Display for TryAcquireError {
752    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
753        match self {
754            TryAcquireError::Closed => write!(fmt, "semaphore closed"),
755            TryAcquireError::NoPermits => write!(fmt, "no permits available"),
756        }
757    }
758}
759
760impl std::error::Error for TryAcquireError {}
761
// SAFETY: `Waiter` is `!Unpin` (it holds `PhantomPinned`). It is only linked
// (in `Semaphore::poll_acquire`) while pinned inside an `Acquire` future, and
// that future's `Drop` unlinks it under the wait-list lock. Handles are plain,
// non-owning `NonNull` pointers.
unsafe impl linked_list::Link for Waiter {
    // The list stores raw, non-owning pointers to waiter nodes.
    type Handle = NonNull<Waiter>;
    type Target = Waiter;

    fn as_raw(handle: &Self::Handle) -> NonNull<Waiter> {
        *handle
    }

    unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
        ptr
    }

    unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
        // Locate the `pointers` field via the helper generated by
        // `generate_addr_of_methods!`.
        Waiter::addr_of_pointers(target)
    }
}