// proc_macro/bridge/client.rs

//! Client-side types.

use std::cell::RefCell;
use std::marker::PhantomData;
use std::sync::atomic::AtomicU32;

use super::*;

9macro_rules! define_client_handles {
10    (
11        'owned: $($oty:ident,)*
12        'interned: $($ity:ident,)*
13    ) => {
14        #[repr(C)]
15        #[allow(non_snake_case)]
16        pub(super) struct HandleCounters {
17            $(pub(super) $oty: AtomicU32,)*
18            $(pub(super) $ity: AtomicU32,)*
19        }
20
21        static COUNTERS: HandleCounters = HandleCounters {
22            $($oty: AtomicU32::new(1),)*
23            $($ity: AtomicU32::new(1),)*
24        };
25
26        $(
27            pub(crate) struct $oty {
28                handle: handle::Handle,
29            }
30
31            impl !Send for $oty {}
32            impl !Sync for $oty {}
33
34            // Forward `Drop::drop` to the inherent `drop` method.
35            impl Drop for $oty {
36                fn drop(&mut self) {
37                    $oty {
38                        handle: self.handle,
39                    }.drop();
40                }
41            }
42
43            impl<S> Encode<S> for $oty {
44                fn encode(self, w: &mut Writer, s: &mut S) {
45                    mem::ManuallyDrop::new(self).handle.encode(w, s);
46                }
47            }
48
49            impl<S> Encode<S> for &$oty {
50                fn encode(self, w: &mut Writer, s: &mut S) {
51                    self.handle.encode(w, s);
52                }
53            }
54
55            impl<S> Encode<S> for &mut $oty {
56                fn encode(self, w: &mut Writer, s: &mut S) {
57                    self.handle.encode(w, s);
58                }
59            }
60
61            impl<S> DecodeMut<'_, '_, S> for $oty {
62                fn decode(r: &mut Reader<'_>, s: &mut S) -> Self {
63                    $oty {
64                        handle: handle::Handle::decode(r, s),
65                    }
66                }
67            }
68        )*
69
70        $(
71            #[derive(Copy, Clone, PartialEq, Eq, Hash)]
72            pub(crate) struct $ity {
73                handle: handle::Handle,
74            }
75
76            impl !Send for $ity {}
77            impl !Sync for $ity {}
78
79            impl<S> Encode<S> for $ity {
80                fn encode(self, w: &mut Writer, s: &mut S) {
81                    self.handle.encode(w, s);
82                }
83            }
84
85            impl<S> DecodeMut<'_, '_, S> for $ity {
86                fn decode(r: &mut Reader<'_>, s: &mut S) -> Self {
87                    $ity {
88                        handle: handle::Handle::decode(r, s),
89                    }
90                }
91            }
92        )*
93    }
94}
95with_api_handle_types!(define_client_handles);
96
// FIXME(eddyb) generate these impls by pattern-matching on the
// names of methods - also could use the presence of `fn drop`
// to distinguish between 'owned and 'interned, above.
// Alternatively, special "modes" could be listed for types in `with_api`
// instead of pattern matching on methods, here and in server decl.

103impl Clone for TokenStream {
104    fn clone(&self) -> Self {
105        self.clone()
106    }
107}
108
109impl Span {
110    pub(crate) fn def_site() -> Span {
111        Bridge::with(|bridge| bridge.globals.def_site)
112    }
113
114    pub(crate) fn call_site() -> Span {
115        Bridge::with(|bridge| bridge.globals.call_site)
116    }
117
118    pub(crate) fn mixed_site() -> Span {
119        Bridge::with(|bridge| bridge.globals.mixed_site)
120    }
121}
122
123impl fmt::Debug for Span {
124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125        f.write_str(&self.debug())
126    }
127}
128
129pub(crate) use super::symbol::Symbol;
130
131macro_rules! define_client_side {
132    ($($name:ident {
133        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
134    }),* $(,)?) => {
135        $(impl $name {
136            $(pub(crate) fn $method($($arg: $arg_ty),*) $(-> $ret_ty)? {
137                Bridge::with(|bridge| {
138                    let mut buf = bridge.cached_buffer.take();
139
140                    buf.clear();
141                    api_tags::Method::$name(api_tags::$name::$method).encode(&mut buf, &mut ());
142                    $($arg.encode(&mut buf, &mut ());)*
143
144                    buf = bridge.dispatch.call(buf);
145
146                    let r = Result::<_, PanicMessage>::decode(&mut &buf[..], &mut ());
147
148                    bridge.cached_buffer = buf;
149
150                    r.unwrap_or_else(|e| panic::resume_unwind(e.into()))
151                })
152            })*
153        })*
154    }
155}
156with_api!(self, self, define_client_side);
157
158struct Bridge<'a> {
159    /// Reusable buffer (only `clear`-ed, never shrunk), primarily
160    /// used for making requests.
161    cached_buffer: Buffer,
162
163    /// Server-side function that the client uses to make requests.
164    dispatch: closure::Closure<'a, Buffer, Buffer>,
165
166    /// Provided globals for this macro expansion.
167    globals: ExpnGlobals<Span>,
168}
169
170impl<'a> !Send for Bridge<'a> {}
171impl<'a> !Sync for Bridge<'a> {}
172
173#[allow(unsafe_code)]
174mod state {
175    use std::cell::{Cell, RefCell};
176    use std::ptr;
177
178    use super::Bridge;
179
180    thread_local! {
181        static BRIDGE_STATE: Cell<*const ()> = const { Cell::new(ptr::null()) };
182    }
183
184    pub(super) fn set<'bridge, R>(state: &RefCell<Bridge<'bridge>>, f: impl FnOnce() -> R) -> R {
185        struct RestoreOnDrop(*const ());
186        impl Drop for RestoreOnDrop {
187            fn drop(&mut self) {
188                BRIDGE_STATE.set(self.0);
189            }
190        }
191
192        let inner = ptr::from_ref(state).cast();
193        let outer = BRIDGE_STATE.replace(inner);
194        let _restore = RestoreOnDrop(outer);
195
196        f()
197    }
198
199    pub(super) fn with<R>(
200        f: impl for<'bridge> FnOnce(Option<&RefCell<Bridge<'bridge>>>) -> R,
201    ) -> R {
202        let state = BRIDGE_STATE.get();
203        // SAFETY: the only place where the pointer is set is in `set`. It puts
204        // back the previous value after the inner call has returned, so we know
205        // that as long as the pointer is not null, it came from a reference to
206        // a `RefCell<Bridge>` that outlasts the call to this function. Since `f`
207        // works the same for any lifetime of the bridge, including the actual
208        // one, we can lie here and say that the lifetime is `'static` without
209        // anyone noticing.
210        let bridge = unsafe { state.cast::<RefCell<Bridge<'static>>>().as_ref() };
211        f(bridge)
212    }
213}
214
215impl Bridge<'_> {
216    fn with<R>(f: impl FnOnce(&mut Bridge<'_>) -> R) -> R {
217        state::with(|state| {
218            let bridge = state.expect("procedural macro API is used outside of a procedural macro");
219            let mut bridge = bridge
220                .try_borrow_mut()
221                .expect("procedural macro API is used while it's already in use");
222            f(&mut bridge)
223        })
224    }
225}
226
227pub(crate) fn is_available() -> bool {
228    state::with(|s| s.is_some())
229}
230
231/// A client-side RPC entry-point, which may be using a different `proc_macro`
232/// from the one used by the server, but can be invoked compatibly.
233///
234/// Note that the (phantom) `I` ("input") and `O` ("output") type parameters
235/// decorate the `Client<I, O>` with the RPC "interface" of the entry-point, but
236/// do not themselves participate in ABI, at all, only facilitate type-checking.
237///
238/// E.g. `Client<TokenStream, TokenStream>` is the common proc macro interface,
239/// used for `#[proc_macro] fn foo(input: TokenStream) -> TokenStream`,
240/// indicating that the RPC input and output will be serialized token streams,
241/// and forcing the use of APIs that take/return `S::TokenStream`, server-side.
242#[repr(C)]
243pub struct Client<I, O> {
244    pub(super) handle_counters: &'static HandleCounters,
245
246    pub(super) run: extern "C" fn(BridgeConfig<'_>) -> Buffer,
247
248    pub(super) _marker: PhantomData<fn(I) -> O>,
249}
250
251impl<I, O> Copy for Client<I, O> {}
252impl<I, O> Clone for Client<I, O> {
253    fn clone(&self) -> Self {
254        *self
255    }
256}
257
258fn maybe_install_panic_hook(force_show_panics: bool) {
259    // Hide the default panic output within `proc_macro` expansions.
260    // NB. the server can't do this because it may use a different std.
261    static HIDE_PANICS_DURING_EXPANSION: Once = Once::new();
262    HIDE_PANICS_DURING_EXPANSION.call_once(|| {
263        let prev = panic::take_hook();
264        panic::set_hook(Box::new(move |info| {
265            // We normally report panics by catching unwinds and passing the payload from the
266            // unwind back to the compiler, but if the panic doesn't unwind we'll abort before
267            // the compiler has a chance to print an error. So we special-case PanicInfo where
268            // can_unwind is false.
269            if force_show_panics || !is_available() || !info.can_unwind() {
270                prev(info)
271            }
272        }));
273    });
274}
275
276/// Client-side helper for handling client panics, entering the bridge,
277/// deserializing input and serializing output.
278// FIXME(eddyb) maybe replace `Bridge::enter` with this?
279fn run_client<A: for<'a, 's> DecodeMut<'a, 's, ()>, R: Encode<()>>(
280    config: BridgeConfig<'_>,
281    f: impl FnOnce(A) -> R,
282) -> Buffer {
283    let BridgeConfig { input: mut buf, dispatch, force_show_panics, .. } = config;
284
285    panic::catch_unwind(panic::AssertUnwindSafe(|| {
286        maybe_install_panic_hook(force_show_panics);
287
288        // Make sure the symbol store is empty before decoding inputs.
289        Symbol::invalidate_all();
290
291        let reader = &mut &buf[..];
292        let (globals, input) = <(ExpnGlobals<Span>, A)>::decode(reader, &mut ());
293
294        // Put the buffer we used for input back in the `Bridge` for requests.
295        let state = RefCell::new(Bridge { cached_buffer: buf.take(), dispatch, globals });
296
297        let output = state::set(&state, || f(input));
298
299        // Take the `cached_buffer` back out, for the output value.
300        buf = RefCell::into_inner(state).cached_buffer;
301
302        // HACK(eddyb) Separate encoding a success value (`Ok(output)`)
303        // from encoding a panic (`Err(e: PanicMessage)`) to avoid
304        // having handles outside the `bridge.enter(|| ...)` scope, and
305        // to catch panics that could happen while encoding the success.
306        //
307        // Note that panics should be impossible beyond this point, but
308        // this is defensively trying to avoid any accidental panicking
309        // reaching the `extern "C"` (which should `abort` but might not
310        // at the moment, so this is also potentially preventing UB).
311        buf.clear();
312        Ok::<_, ()>(output).encode(&mut buf, &mut ());
313    }))
314    .map_err(PanicMessage::from)
315    .unwrap_or_else(|e| {
316        buf.clear();
317        Err::<(), _>(e).encode(&mut buf, &mut ());
318    });
319
320    // Now that a response has been serialized, invalidate all symbols
321    // registered with the interner.
322    Symbol::invalidate_all();
323    buf
324}
325
326impl Client<crate::TokenStream, crate::TokenStream> {
327    pub const fn expand1(f: impl Fn(crate::TokenStream) -> crate::TokenStream + Copy) -> Self {
328        Client {
329            handle_counters: &COUNTERS,
330            run: super::selfless_reify::reify_to_extern_c_fn_hrt_bridge(move |bridge| {
331                run_client(bridge, |input| f(crate::TokenStream(Some(input))).0)
332            }),
333            _marker: PhantomData,
334        }
335    }
336}
337
338impl Client<(crate::TokenStream, crate::TokenStream), crate::TokenStream> {
339    pub const fn expand2(
340        f: impl Fn(crate::TokenStream, crate::TokenStream) -> crate::TokenStream + Copy,
341    ) -> Self {
342        Client {
343            handle_counters: &COUNTERS,
344            run: super::selfless_reify::reify_to_extern_c_fn_hrt_bridge(move |bridge| {
345                run_client(bridge, |(input, input2)| {
346                    f(crate::TokenStream(Some(input)), crate::TokenStream(Some(input2))).0
347                })
348            }),
349            _marker: PhantomData,
350        }
351    }
352}
353
354#[repr(C)]
355#[derive(Copy, Clone)]
356pub enum ProcMacro {
357    CustomDerive {
358        trait_name: &'static str,
359        attributes: &'static [&'static str],
360        client: Client<crate::TokenStream, crate::TokenStream>,
361    },
362
363    Attr {
364        name: &'static str,
365        client: Client<(crate::TokenStream, crate::TokenStream), crate::TokenStream>,
366    },
367
368    Bang {
369        name: &'static str,
370        client: Client<crate::TokenStream, crate::TokenStream>,
371    },
372}
373
374impl ProcMacro {
375    pub fn name(&self) -> &'static str {
376        match self {
377            ProcMacro::CustomDerive { trait_name, .. } => trait_name,
378            ProcMacro::Attr { name, .. } => name,
379            ProcMacro::Bang { name, .. } => name,
380        }
381    }
382
383    pub const fn custom_derive(
384        trait_name: &'static str,
385        attributes: &'static [&'static str],
386        expand: impl Fn(crate::TokenStream) -> crate::TokenStream + Copy,
387    ) -> Self {
388        ProcMacro::CustomDerive { trait_name, attributes, client: Client::expand1(expand) }
389    }
390
391    pub const fn attr(
392        name: &'static str,
393        expand: impl Fn(crate::TokenStream, crate::TokenStream) -> crate::TokenStream + Copy,
394    ) -> Self {
395        ProcMacro::Attr { name, client: Client::expand2(expand) }
396    }
397
398    pub const fn bang(
399        name: &'static str,
400        expand: impl Fn(crate::TokenStream) -> crate::TokenStream + Copy,
401    ) -> Self {
402        ProcMacro::Bang { name, client: Client::expand1(expand) }
403    }
404}