tblgen/
init.rs

1// Original work Copyright 2016 Alexander Stocko <as@coder.gg>.
2// Modified work Copyright 2023 Daan Vanoverloop
3// See the COPYRIGHT file at the top-level directory of this distribution.
4//
5// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8// option. This file may not be copied, modified, or distributed
9// except according to those terms.
10
11//! This module contains smart pointers that reference various `Init` types in
12//! TableGen.
13//!
14//! Init reference types can be converted to Rust types using [`Into`] and
15//! [`TryInto`]. Most conversions are cheap, except for conversion to
16//! [`String`].
17
18use crate::{
19    raw::{
20        TableGenRecTyKind, TableGenTypedInitRef, tableGenBitInitGetValue,
21        tableGenBitsInitGetBitInit, tableGenBitsInitGetNumBits, tableGenDagRecordArgName,
22        tableGenDagRecordGet, tableGenDagRecordNumArgs, tableGenDagRecordOperator,
23        tableGenDefInitGetValue, tableGenInitPrint, tableGenInitRecType, tableGenIntInitGetValue,
24        tableGenListRecordGet, tableGenListRecordNumElements, tableGenStringInitGetValue,
25    },
26    string_ref::StringRef,
27    util::print_callback,
28};
29use paste::paste;
30
31use crate::{
32    error::{Error, TableGenError},
33    record::Record,
34};
35use std::{
36    ffi::c_void,
37    fmt::{self, Debug, Display, Formatter},
38    marker::PhantomData,
39    str::Utf8Error,
40    string::FromUtf8Error,
41};
42
/// Enum that holds a reference to a `TypedInit`.
///
/// Every variant except [`TypedInit::Invalid`] wraps the typed reference for
/// one kind of init. The inner reference can be extracted with the `as_*`
/// methods or converted to a Rust value through the [`TryFrom`]
/// implementations in this module.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum TypedInit<'a> {
    /// A single bit; convertible to [`bool`].
    Bit(BitInit<'a>),
    /// A fixed-width sequence of bits; convertible to `Vec<bool>` or
    /// `Vec<BitInit>`.
    Bits(BitsInit<'a>),
    /// A code value. It shares [`StringInit`] with [`TypedInit::String`].
    // NOTE(review): `TypedInit::from_raw` in this file never produces this
    // variant; confirm whether it is still constructed anywhere.
    Code(StringInit<'a>),
    /// An integer; convertible to [`i64`].
    Int(IntInit<'a>),
    /// A string; convertible to [`String`] or `&str`.
    String(StringInit<'a>),
    /// A list of inits.
    List(ListInit<'a>),
    /// A dag: an operator together with its arguments.
    Dag(DagInit<'a>),
    /// A reference to a record definition; convertible to [`Record`].
    Def(DefInit<'a>),
    /// Produced by `from_raw` when the record type kind of the raw init
    /// matches none of the kinds handled there.
    Invalid,
}
56
57impl TypedInit<'_> {
58    fn variant_name(&self) -> &'static str {
59        match self {
60            TypedInit::Bit(_) => "Bit",
61            TypedInit::Bits(_) => "Bits",
62            TypedInit::Code(_) => "Code",
63            TypedInit::Int(_) => "Int",
64            TypedInit::String(_) => "String",
65            TypedInit::List(_) => "List",
66            TypedInit::Dag(_) => "Dag",
67            TypedInit::Def(_) => "Def",
68            TypedInit::Invalid => "Invalid",
69        }
70    }
71}
72
73impl Display for TypedInit<'_> {
74    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
75        match self {
76            Self::Bit(init) => write!(f, "{}", &init),
77            Self::Bits(init) => write!(f, "{}", &init),
78            Self::Code(init) => write!(f, "{}", &init),
79            Self::Int(init) => write!(f, "{}", &init),
80            Self::String(init) => write!(f, "{}", &init),
81            Self::List(init) => write!(f, "{}", &init),
82            Self::Dag(init) => write!(f, "{}", &init),
83            Self::Def(init) => write!(f, "{}", &init),
84            Self::Invalid => write!(f, "Invalid"),
85        }
86    }
87}
88
89impl Debug for TypedInit<'_> {
90    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
91        write!(f, "TypedInit(")?;
92        let name = self.variant_name();
93        write!(f, "{name}(")?;
94        match self {
95            Self::Bit(init) => write!(f, "{:#?}", &init),
96            Self::Bits(init) => write!(f, "{:#?}", &init),
97            Self::Code(init) => write!(f, "{:#?}", &init),
98            Self::Int(init) => write!(f, "{:#?}", &init),
99            Self::String(init) => write!(f, "{:#?}", &init),
100            Self::List(init) => write!(f, "{:#?}", &init),
101            Self::Dag(init) => write!(f, "{:#?}", &init),
102            Self::Def(init) => write!(f, "{:#?}", &init),
103            Self::Invalid => write!(f, ""),
104        }?;
105        write!(f, "))")
106    }
107}
108
// Generates an `as_<name>` method on `TypedInit` that returns the inner init
// when `self` is the requested variant and an `InitConversion` error
// otherwise.
//
// * `$name`    - suffix of the generated method (`as_$name`).
// * `$variant` - the `TypedInit` variant to extract.
// * `$type`    - the init type wrapped by that variant, without its lifetime.
macro_rules! as_inner {
    ($name:ident, $variant:ident, $type:ty) => {
        paste! {
            pub fn [<as_ $name>](self) -> Result<$type<'a>, Error> {
                if let Self::$variant(inner) = self {
                    return Ok(inner);
                }

                Err(TableGenError::InitConversion {
                    from: self.variant_name(),
                    to: std::any::type_name::<$type>(),
                }
                .into())
            }
        }
    };
}
124
// Implements `TryFrom<TypedInit<'a>>` for `$type`.
//
// The conversion succeeds only when the value is the `$variant` variant; the
// wrapped init is then converted with its own `TryFrom`/`From`
// implementation. Any other variant yields an `InitConversion` error.
//
// * `$variant` - the `TypedInit` variant accepted by the conversion.
// * `$init`    - the init type wrapped by that variant (not used in the
//                expansion; it documents the source type at the call site).
// * `$type`    - the target type of the conversion.
macro_rules! try_into {
    ($variant:ident, $init:ty, $type:ty) => {
        impl<'a> TryFrom<TypedInit<'a>> for $type {
            type Error = Error;

            fn try_from(value: TypedInit<'a>) -> Result<Self, Self::Error> {
                if let TypedInit::$variant(init) = value {
                    let converted = Self::try_from(init).map_err(TableGenError::from)?;
                    return Ok(converted);
                }

                Err(TableGenError::InitConversion {
                    from: value.variant_name(),
                    to: std::any::type_name::<$type>(),
                }
                .into())
            }
        }
    };
}
143
// `TryFrom<TypedInit>` implementations generated by `try_into!`. Each line
// reads: accepted variant, wrapped init type, conversion target.
try_into!(Bit, BitInit<'a>, bool);
// `Bits` converts either to the individual bit inits or directly to booleans.
try_into!(Bits, BitsInit<'a>, Vec<BitInit<'a>>);
try_into!(Bits, BitsInit<'a>, Vec<bool>);
try_into!(Int, IntInit<'a>, i64);
try_into!(Def, DefInit<'a>, Record<'a>);
// Lists and dags have no plain Rust equivalent here, so the target is the
// init reference type itself.
try_into!(List, ListInit<'a>, ListInit<'a>);
try_into!(Dag, DagInit<'a>, DagInit<'a>);
151
152impl<'a> TryFrom<TypedInit<'a>> for String {
153    type Error = Error;
154
155    fn try_from(value: TypedInit<'a>) -> Result<Self, Self::Error> {
156        match value {
157            TypedInit::String(v) | TypedInit::Code(v) => {
158                Ok(Self::try_from(v).map_err(TableGenError::from)?)
159            }
160            _ => Err(TableGenError::InitConversion {
161                from: value.variant_name(),
162                to: std::any::type_name::<String>(),
163            }
164            .into()),
165        }
166    }
167}
168
169impl<'a> TryFrom<TypedInit<'a>> for &'a str {
170    type Error = Error;
171
172    fn try_from(value: TypedInit<'a>) -> Result<Self, Self::Error> {
173        match value {
174            TypedInit::String(v) | TypedInit::Code(v) => {
175                Ok(v.to_str().map_err(TableGenError::from)?)
176            }
177            _ => Err(TableGenError::InitConversion {
178                from: value.variant_name(),
179                to: std::any::type_name::<&'a str>(),
180            }
181            .into()),
182        }
183    }
184}
185
186impl<'a> TypedInit<'a> {
187    as_inner!(bit, Bit, BitInit);
188    as_inner!(bits, Bits, BitsInit);
189    as_inner!(code, Code, StringInit);
190    as_inner!(int, Int, IntInit);
191    as_inner!(string, String, StringInit);
192    as_inner!(list, List, ListInit);
193    as_inner!(dag, Dag, DagInit);
194    as_inner!(def, Def, DefInit);
195
196    /// Creates a new init from a raw object.
197    ///
198    /// # Safety
199    ///
200    /// The raw object must be valid.
201    #[allow(non_upper_case_globals)]
202    pub unsafe fn from_raw(init: TableGenTypedInitRef) -> Self {
203        unsafe {
204            let t = tableGenInitRecType(init);
205
206            use TableGenRecTyKind::*;
207            match t {
208                TableGenBitRecTyKind => Self::Bit(BitInit::from_raw(init)),
209                TableGenBitsRecTyKind => Self::Bits(BitsInit::from_raw(init)),
210                TableGenDagRecTyKind => TypedInit::Dag(DagInit::from_raw(init)),
211                TableGenIntRecTyKind => TypedInit::Int(IntInit::from_raw(init)),
212                TableGenListRecTyKind => TypedInit::List(ListInit::from_raw(init)),
213                TableGenRecordRecTyKind => Self::Def(DefInit::from_raw(init)),
214                TableGenStringRecTyKind => Self::String(StringInit::from_raw(init)),
215                _ => Self::Invalid,
216            }
217        }
218    }
219}
220
// Declares a copyable reference type `$name<'a>` around a raw
// `TableGenTypedInitRef`, together with:
//
// * an unsafe `from_raw` constructor,
// * a `Display` impl that prints the init through `tableGenInitPrint`,
// * a `Debug` impl that wraps the `Display` output as `$name(...)`.
macro_rules! init {
    ($name:ident) => {
        #[derive(Clone, Copy, PartialEq, Eq)]
        pub struct $name<'a> {
            raw: TableGenTypedInitRef,
            _reference: PhantomData<&'a TableGenTypedInitRef>,
        }

        impl<'a> $name<'a> {
            /// Creates a new init from a raw object.
            ///
            /// # Safety
            ///
            /// The raw object must be valid.
            pub unsafe fn from_raw(raw: TableGenTypedInitRef) -> Self {
                let _reference = PhantomData;
                Self { raw, _reference }
            }
        }

        impl Display for $name<'_> {
            fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
                // The callback receives the formatter plus a slot for the
                // result through the opaque user-data pointer.
                let mut state = (formatter, Ok(()));

                unsafe {
                    tableGenInitPrint(
                        self.raw,
                        Some(print_callback),
                        &mut state as *mut _ as *mut c_void,
                    );
                }

                let (_, outcome) = state;
                outcome
            }
        }

        impl Debug for $name<'_> {
            fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
                formatter.write_str(concat!(stringify!($name), "("))?;
                Display::fmt(self, formatter)?;
                formatter.write_str(")")
            }
        }
    };
}
268
269init!(BitInit);
270
271impl<'a> From<BitInit<'a>> for bool {
272    fn from(value: BitInit<'a>) -> Self {
273        let mut bit = -1;
274        unsafe { tableGenBitInitGetValue(value.raw, &mut bit) };
275        assert!(bit == 0 || bit == 1);
276        bit != 0
277    }
278}
279
280init!(BitsInit);
281
282impl<'a> From<BitsInit<'a>> for Vec<BitInit<'a>> {
283    fn from(value: BitsInit<'a>) -> Self {
284        (0..value.num_bits())
285            .map(|i| value.bit(i).expect("index within range"))
286            .collect()
287    }
288}
289
290impl<'a> From<BitsInit<'a>> for Vec<bool> {
291    fn from(value: BitsInit<'a>) -> Self {
292        (0..value.num_bits())
293            .map(|i| value.bit(i).expect("index within range").into())
294            .collect()
295    }
296}
297
298impl<'a> BitsInit<'a> {
299    /// Returns the bit at the given index.
300    pub fn bit(self, index: usize) -> Option<BitInit<'a>> {
301        let bit = unsafe { tableGenBitsInitGetBitInit(self.raw, index) };
302        if !bit.is_null() {
303            Some(unsafe { BitInit::from_raw(bit) })
304        } else {
305            None
306        }
307    }
308
309    /// Returns the number of bits in the init.
310    pub fn num_bits(self) -> usize {
311        let mut len = 0;
312        unsafe { tableGenBitsInitGetNumBits(self.raw, &mut len) };
313        len
314    }
315}
316
317init!(IntInit);
318
319impl<'a> From<IntInit<'a>> for i64 {
320    fn from(value: IntInit<'a>) -> Self {
321        let mut int: i64 = 0;
322        let res = unsafe { tableGenIntInitGetValue(value.raw, &mut int) };
323        assert!(res > 0);
324        int
325    }
326}
327
328init!(StringInit);
329
330impl<'a> TryFrom<StringInit<'a>> for String {
331    type Error = FromUtf8Error;
332
333    fn try_from(value: StringInit<'a>) -> Result<Self, Self::Error> {
334        String::from_utf8(value.as_bytes().to_vec())
335    }
336}
337
338impl<'a> TryFrom<StringInit<'a>> for &'a str {
339    type Error = Utf8Error;
340
341    fn try_from(value: StringInit<'a>) -> Result<Self, Utf8Error> {
342        value.to_str()
343    }
344}
345
346impl<'a> StringInit<'a> {
347    /// Converts the string init to a [`&str`].
348    ///
349    /// # Errors
350    ///
351    /// Returns a [`Utf8Error`] if the string init does not contain valid UTF-8.
352    pub fn to_str(self) -> Result<&'a str, Utf8Error> {
353        unsafe { StringRef::from_raw(tableGenStringInitGetValue(self.raw)) }.try_into()
354    }
355
356    /// Gets the string init as a slice of bytes.
357    pub fn as_bytes(self) -> &'a [u8] {
358        unsafe { StringRef::from_raw(tableGenStringInitGetValue(self.raw)) }.into()
359    }
360}
361
362init!(DefInit);
363
364impl<'a> From<DefInit<'a>> for Record<'a> {
365    fn from(value: DefInit<'a>) -> Self {
366        unsafe { Record::from_raw(tableGenDefInitGetValue(value.raw)) }
367    }
368}
369
370init!(DagInit);
371
372impl<'a> DagInit<'a> {
373    /// Returns an iterator over the arguments of the dag.
374    ///
375    /// The iterator yields tuples `(&str, TypedInit)`.
376    pub fn args(self) -> DagIter<'a> {
377        DagIter {
378            dag: self,
379            index: 0,
380        }
381    }
382
383    /// Returns the operator of the dag as a [`Record`].
384    pub fn operator(self) -> Record<'a> {
385        unsafe { Record::from_raw(tableGenDagRecordOperator(self.raw)) }
386    }
387
388    /// Returns the number of arguments for this dag.
389    pub fn num_args(self) -> usize {
390        unsafe { tableGenDagRecordNumArgs(self.raw) }
391    }
392
393    /// Returns the name of the argument at the given index.
394    pub fn name(self, index: usize) -> Option<&'a str> {
395        unsafe { StringRef::from_option_raw(tableGenDagRecordArgName(self.raw, index)) }
396            .and_then(|s| s.try_into().ok())
397    }
398
399    /// Returns the argument at the given index.
400    pub fn get(self, index: usize) -> Option<TypedInit<'a>> {
401        let value = unsafe { tableGenDagRecordGet(self.raw, index) };
402        if !value.is_null() {
403            Some(unsafe { TypedInit::from_raw(value) })
404        } else {
405            None
406        }
407    }
408}
409
/// Iterator over the arguments of a [`DagInit`], created by
/// [`DagInit::args`].
///
/// Yields `(name, value)` pairs in argument order.
#[derive(Debug, Clone)]
pub struct DagIter<'a> {
    // The dag whose arguments are being visited.
    dag: DagInit<'a>,
    // Index of the argument that the next call to `next` will look up.
    index: usize,
}
415
416impl<'a> Iterator for DagIter<'a> {
417    type Item = (&'a str, TypedInit<'a>);
418
419    fn next(&mut self) -> Option<Self::Item> {
420        let next = self.dag.get(self.index);
421        let name = self.dag.name(self.index);
422        self.index += 1;
423        if let (Some(next), Some(name)) = (next, name) {
424            Some((name, next))
425        } else {
426            None
427        }
428    }
429}
430
431init!(ListInit);
432
433impl<'a> ListInit<'a> {
434    /// Returns an iterator over the elements of the list.
435    ///
436    /// The iterator yields values of type [`TypedInit`].
437    pub fn iter(self) -> ListIter<'a> {
438        ListIter {
439            list: self,
440            index: 0,
441        }
442    }
443
444    /// Returns true if the list is empty.
445    pub fn is_empty(self) -> bool {
446        self.len() == 0
447    }
448
449    /// Returns the length of the list.
450    pub fn len(self) -> usize {
451        unsafe { tableGenListRecordNumElements(self.raw) }
452    }
453
454    /// Returns the element at the given index in the list.
455    pub fn get(self, index: usize) -> Option<TypedInit<'a>> {
456        let value = unsafe { tableGenListRecordGet(self.raw, index) };
457        if !value.is_null() {
458            Some(unsafe { TypedInit::from_raw(value) })
459        } else {
460            None
461        }
462    }
463}
464
/// Iterator over the elements of a [`ListInit`], created by
/// [`ListInit::iter`].
#[derive(Debug, Clone)]
pub struct ListIter<'a> {
    // The list whose elements are being visited.
    list: ListInit<'a>,
    // Index of the element that the next call to `next` will look up.
    index: usize,
}
470
471impl<'a> Iterator for ListIter<'a> {
472    type Item = TypedInit<'a>;
473
474    fn next(&mut self) -> Option<TypedInit<'a>> {
475        let next = unsafe { tableGenListRecordGet(self.list.raw, self.index) };
476        self.index += 1;
477        if !next.is_null() {
478            Some(unsafe { TypedInit::from_raw(next) })
479        } else {
480            None
481        }
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TableGenParser;

    // Generates a test that parses `def A { <$td_field> }`, looks up field `a`
    // and checks that converting its init yields `$expected`.
    macro_rules! test_init {
        ($name:ident, $td_field:expr, $expected:expr) => {
            #[test]
            fn $name() {
                let rk = TableGenParser::new()
                    .add_source(&format!(
                        "
                    def A {{
                        {}
                    }}
                    ",
                        $td_field
                    ))
                    .unwrap()
                    .parse()
                    .expect("valid tablegen");
                let a = rk
                    .def("A")
                    .expect("def A exists")
                    .value("a")
                    .expect("field a exists");
                assert_eq!(a.init.try_into(), Ok($expected));
            }
        };
    }

    test_init!(bit, "bit a = 0;", false);
    // The expected vector is in the reverse order of the TableGen literal.
    test_init!(
        bits,
        "bits<4> a = { 0, 0, 1, 0 };",
        vec![false, true, false, false]
    );
    test_init!(int, "int a = 42;", 42);
    test_init!(string, "string a = \"hi\";", "hi");

    #[test]
    fn dag() {
        let rk = TableGenParser::new()
            .add_source(
                "
                def ins;
                def X {
                    int i = 4;
                }
                def Y {
                    string s = \"test\";
                }
                def A {
                    dag args = (ins X:$src1, Y:$src2);
                }
                ",
            )
            .unwrap()
            .parse()
            .expect("valid tablegen");
        let a: DagInit = rk
            .def("A")
            .expect("def A exists")
            .value("args")
            .expect("field args exists")
            .try_into()
            .expect("is dag init");
        assert_eq!(a.num_args(), 2);
        assert_eq!(a.operator().name(), Ok("ins"));
        let mut args = a.args();
        // Inspect the first argument on a clone so `args` itself stays at the
        // start; `nth(1)` below then selects the second argument.
        assert_eq!(
            args.clone().next().map(|(name, init)| (
                name,
                Record::try_from(init).expect("is record").int_value("i")
            )),
            Some(("src1", Ok(4)))
        );
        assert_eq!(
            args.nth(1).map(|(name, init)| (
                name,
                Record::try_from(init).expect("is record").string_value("s")
            )),
            Some(("src2", Ok("test".into())))
        );
    }

    #[test]
    fn list() {
        let rk = TableGenParser::new()
            .add_source(
                "
                def A {
                    list<int> l = [0, 1, 2, 3];
                }
                ",
            )
            .unwrap()
            .parse()
            .expect("valid tablegen");
        let l: ListInit = rk
            .def("A")
            .expect("def A exists")
            .value("l")
            // Fixed: the message previously named the `args` field of the
            // `dag` test, which would mislead on failure.
            .expect("field l exists")
            .try_into()
            .expect("is list init");
        assert_eq!(l.len(), 4);
        let iter = l.iter();
        // Every check works on a fresh clone, so each starts at index 0.
        assert_eq!(iter.clone().count(), 4);
        assert_eq!(iter.clone().next().unwrap().try_into(), Ok(0));
        assert_eq!(iter.clone().nth(1).unwrap().try_into(), Ok(1));
        assert_eq!(iter.clone().nth(2).unwrap().try_into(), Ok(2));
        assert_eq!(iter.clone().nth(3).unwrap().try_into(), Ok(3));
    }
}