1use super::one::RefMut;
2use crate::lock::RwLockWriteGuard;
3use crate::util;
4use crate::util::SharedValue;
5use crate::HashMap;
6use core::hash::{BuildHasher, Hash};
7use core::mem;
8use core::ptr;
9use std::collections::hash_map::RandomState;
10
11pub enum Entry<'a, K, V, S = RandomState> {
12 Occupied(OccupiedEntry<'a, K, V, S>),
13 Vacant(VacantEntry<'a, K, V, S>),
14}
15
16impl<'a, K: Eq + Hash, V, S: BuildHasher> Entry<'a, K, V, S> {
17 pub fn and_modify(self, f: impl FnOnce(&mut V)) -> Self {
19 match self {
20 Entry::Occupied(mut entry) => {
21 f(entry.get_mut());
22
23 Entry::Occupied(entry)
24 }
25
26 Entry::Vacant(entry) => Entry::Vacant(entry),
27 }
28 }
29
30 pub fn key(&self) -> &K {
32 match *self {
33 Entry::Occupied(ref entry) => entry.key(),
34 Entry::Vacant(ref entry) => entry.key(),
35 }
36 }
37
38 pub fn into_key(self) -> K {
40 match self {
41 Entry::Occupied(entry) => entry.into_key(),
42 Entry::Vacant(entry) => entry.into_key(),
43 }
44 }
45
46 pub fn or_default(self) -> RefMut<'a, K, V, S>
49 where
50 V: Default,
51 {
52 match self {
53 Entry::Occupied(entry) => entry.into_ref(),
54 Entry::Vacant(entry) => entry.insert(V::default()),
55 }
56 }
57
58 pub fn or_insert(self, value: V) -> RefMut<'a, K, V, S> {
61 match self {
62 Entry::Occupied(entry) => entry.into_ref(),
63 Entry::Vacant(entry) => entry.insert(value),
64 }
65 }
66
67 pub fn or_insert_with(self, value: impl FnOnce() -> V) -> RefMut<'a, K, V, S> {
70 match self {
71 Entry::Occupied(entry) => entry.into_ref(),
72 Entry::Vacant(entry) => entry.insert(value()),
73 }
74 }
75
76 pub fn or_try_insert_with<E>(
77 self,
78 value: impl FnOnce() -> Result<V, E>,
79 ) -> Result<RefMut<'a, K, V, S>, E> {
80 match self {
81 Entry::Occupied(entry) => Ok(entry.into_ref()),
82 Entry::Vacant(entry) => Ok(entry.insert(value()?)),
83 }
84 }
85
86 pub fn insert(self, value: V) -> RefMut<'a, K, V, S> {
88 match self {
89 Entry::Occupied(mut entry) => {
90 entry.insert(value);
91 entry.into_ref()
92 }
93 Entry::Vacant(entry) => entry.insert(value),
94 }
95 }
96
97 pub fn insert_entry(self, value: V) -> OccupiedEntry<'a, K, V, S>
104 where
105 K: Clone,
106 {
107 match self {
108 Entry::Occupied(mut entry) => {
109 entry.insert(value);
110 entry
111 }
112 Entry::Vacant(entry) => entry.insert_entry(value),
113 }
114 }
115}
116
117pub struct VacantEntry<'a, K, V, S = RandomState> {
118 shard: RwLockWriteGuard<'a, HashMap<K, V, S>>,
119 key: K,
120}
121
122unsafe impl<'a, K: Eq + Hash + Sync, V: Sync, S: BuildHasher> Send for VacantEntry<'a, K, V, S> {}
123unsafe impl<'a, K: Eq + Hash + Sync, V: Sync, S: BuildHasher> Sync for VacantEntry<'a, K, V, S> {}
124
125impl<'a, K: Eq + Hash, V, S: BuildHasher> VacantEntry<'a, K, V, S> {
126 pub(crate) unsafe fn new(shard: RwLockWriteGuard<'a, HashMap<K, V, S>>, key: K) -> Self {
127 Self { shard, key }
128 }
129
130 pub fn insert(mut self, value: V) -> RefMut<'a, K, V, S> {
131 unsafe {
132 let c: K = ptr::read(&self.key);
133
134 self.shard.insert(self.key, SharedValue::new(value));
135
136 let (k, v) = self.shard.get_key_value(&c).unwrap();
137
138 let k = util::change_lifetime_const(k);
139
140 let v = &mut *v.as_ptr();
141
142 let r = RefMut::new(self.shard, k, v);
143
144 mem::forget(c);
145
146 r
147 }
148 }
149
150 pub fn insert_entry(mut self, value: V) -> OccupiedEntry<'a, K, V, S>
152 where
153 K: Clone,
154 {
155 unsafe {
156 self.shard.insert(self.key.clone(), SharedValue::new(value));
157
158 let (k, v) = self.shard.get_key_value(&self.key).unwrap();
159
160 let kptr: *const K = k;
161 let vptr: *mut V = v.as_ptr();
162 OccupiedEntry::new(self.shard, self.key, (kptr, vptr))
163 }
164 }
165
166 pub fn into_key(self) -> K {
167 self.key
168 }
169
170 pub fn key(&self) -> &K {
171 &self.key
172 }
173}
174
175pub struct OccupiedEntry<'a, K, V, S = RandomState> {
176 shard: RwLockWriteGuard<'a, HashMap<K, V, S>>,
177 elem: (*const K, *mut V),
178 key: K,
179}
180
181unsafe impl<'a, K: Eq + Hash + Sync, V: Sync, S: BuildHasher> Send for OccupiedEntry<'a, K, V, S> {}
182unsafe impl<'a, K: Eq + Hash + Sync, V: Sync, S: BuildHasher> Sync for OccupiedEntry<'a, K, V, S> {}
183
184impl<'a, K: Eq + Hash, V, S: BuildHasher> OccupiedEntry<'a, K, V, S> {
185 pub(crate) unsafe fn new(
186 shard: RwLockWriteGuard<'a, HashMap<K, V, S>>,
187 key: K,
188 elem: (*const K, *mut V),
189 ) -> Self {
190 Self { shard, elem, key }
191 }
192
193 pub fn get(&self) -> &V {
194 unsafe { &*self.elem.1 }
195 }
196
197 pub fn get_mut(&mut self) -> &mut V {
198 unsafe { &mut *self.elem.1 }
199 }
200
201 pub fn insert(&mut self, value: V) -> V {
202 mem::replace(self.get_mut(), value)
203 }
204
205 pub fn into_ref(self) -> RefMut<'a, K, V, S> {
206 unsafe { RefMut::new(self.shard, self.elem.0, self.elem.1) }
207 }
208
209 pub fn into_key(self) -> K {
210 self.key
211 }
212
213 pub fn key(&self) -> &K {
214 unsafe { &*self.elem.0 }
215 }
216
217 pub fn remove(mut self) -> V {
218 let key = unsafe { &*self.elem.0 };
219 self.shard.remove(key).unwrap().into_inner()
220 }
221
222 pub fn remove_entry(mut self) -> (K, V) {
223 let key = unsafe { &*self.elem.0 };
224 let (k, v) = self.shard.remove_entry(key).unwrap();
225 (k, v.into_inner())
226 }
227
228 pub fn replace_entry(mut self, value: V) -> (K, V) {
229 let nk = self.key;
230 let key = unsafe { &*self.elem.0 };
231 let (k, v) = self.shard.remove_entry(key).unwrap();
232 self.shard.insert(nk, SharedValue::new(value));
233 (k, v.into_inner())
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DashMap;

    /// `insert_entry` on a vacant entry stores the value, hands back an occupied entry
    /// that reads it, and leaves the value visible through the map afterwards.
    #[test]
    fn test_insert_entry_into_vacant() {
        let map: DashMap<u32, u32> = DashMap::new();

        let vacant = map.entry(1);
        assert!(matches!(vacant, Entry::Vacant(_)));

        let occupied = vacant.insert_entry(2);
        assert_eq!(*occupied.get(), 2);

        // Release the shard's write guard before reading through the map.
        drop(occupied);

        let stored = map.get(&1).unwrap();
        assert_eq!(*stored, 2);
    }

    /// `insert_entry` on an occupied entry overwrites the existing value.
    #[test]
    fn test_insert_entry_into_occupied() {
        let map: DashMap<u32, u32> = DashMap::new();
        map.insert(1, 1000);

        let existing = map.entry(1);
        assert!(matches!(&existing, Entry::Occupied(e) if *e.get() == 1000));

        let occupied = existing.insert_entry(2);
        assert_eq!(*occupied.get(), 2);

        // Release the shard's write guard before reading through the map.
        drop(occupied);

        let stored = map.get(&1).unwrap();
        assert_eq!(*stored, 2);
    }
}
278}