melior/ir/attribute/dictionary.rs

1use super::{Attribute, AttributeLike};
2use crate::{Context, Error, StringRef, ir::Identifier};
3use mlir_sys::{
4    MlirAttribute, mlirDictionaryAttrGet, mlirDictionaryAttrGetElement,
5    mlirDictionaryAttrGetElementByName, mlirDictionaryAttrGetNumElements, mlirNamedAttributeGet,
6};
7
8/// A dictionary attribute.
9#[derive(Clone, Copy, Hash)]
10pub struct DictionaryAttribute<'c> {
11    attribute: Attribute<'c>,
12}
13
14impl<'c> DictionaryAttribute<'c> {
15    /// Creates a dictionary attribute.
16    pub fn new(context: &'c Context, elements: &[(Identifier<'c>, Attribute<'c>)]) -> Self {
17        let named: Vec<_> = elements
18            .iter()
19            .map(|(id, attr)| unsafe { mlirNamedAttributeGet(id.to_raw(), attr.to_raw()) })
20            .collect();
21        unsafe {
22            Self::from_raw(mlirDictionaryAttrGet(
23                context.to_raw(),
24                named.len() as isize,
25                named.as_ptr(),
26            ))
27        }
28    }
29
30    /// Returns the number of elements.
31    pub fn len(&self) -> usize {
32        (unsafe { mlirDictionaryAttrGetNumElements(self.to_raw()) }) as usize
33    }
34
35    /// Checks if the dictionary attribute is empty.
36    pub fn is_empty(&self) -> bool {
37        self.len() == 0
38    }
39
40    /// Returns the element at the given index.
41    pub fn element(&self, index: usize) -> Result<(Identifier<'c>, Attribute<'c>), Error> {
42        if index < self.len() {
43            let named = unsafe { mlirDictionaryAttrGetElement(self.to_raw(), index as isize) };
44            Ok(unsafe {
45                (
46                    Identifier::from_raw(named.name),
47                    Attribute::from_raw(named.attribute),
48                )
49            })
50        } else {
51            Err(Error::PositionOutOfBounds {
52                name: "dictionary element",
53                value: self.to_string(),
54                index,
55            })
56        }
57    }
58
59    /// Returns the attribute with the given name, or `None` if not found.
60    pub fn element_by_name(&self, name: &str) -> Option<Attribute<'c>> {
61        unsafe {
62            Attribute::from_option_raw(mlirDictionaryAttrGetElementByName(
63                self.to_raw(),
64                StringRef::new(name).to_raw(),
65            ))
66        }
67    }
68}
69
70attribute_traits!(DictionaryAttribute, is_dictionary, "dictionary");
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ir::{attribute::IntegerAttribute, r#type::IntegerType},
        test::create_test_context,
    };

    /// Builds a 64-bit integer attribute holding `value`.
    fn integer_attribute(context: &Context, value: i64) -> Attribute<'_> {
        IntegerAttribute::new(IntegerType::new(context, 64).into(), value).into()
    }

    #[test]
    fn new_empty() {
        let context = create_test_context();
        let attribute = DictionaryAttribute::new(&context, &[]);

        assert!(attribute.is_empty());
        assert_eq!(attribute.len(), 0);
    }

    #[test]
    fn len() {
        let context = create_test_context();
        let id = Identifier::new(&context, "foo");
        let val = integer_attribute(&context, 42);
        let attribute = DictionaryAttribute::new(&context, &[(id, val)]);

        assert_eq!(attribute.len(), 1);
        assert!(!attribute.is_empty());
    }

    #[test]
    fn element() {
        let context = create_test_context();
        let id = Identifier::new(&context, "bar");
        let val = integer_attribute(&context, 7);
        let attribute = DictionaryAttribute::new(&context, &[(id, val)]);

        let (got_id, got_val) = attribute.element(0).unwrap();
        assert_eq!(got_id.as_string_ref().as_str().unwrap(), "bar");
        assert_eq!(got_val, val);
        assert!(matches!(
            attribute.element(1),
            Err(Error::PositionOutOfBounds { .. })
        ));
    }

    #[test]
    fn element_out_of_bounds_on_empty() {
        let context = create_test_context();
        let attribute = DictionaryAttribute::new(&context, &[]);

        assert!(matches!(
            attribute.element(0),
            Err(Error::PositionOutOfBounds { .. })
        ));
    }

    #[test]
    fn element_by_name() {
        let context = create_test_context();
        let id = Identifier::new(&context, "baz");
        let val = integer_attribute(&context, 99);
        let attribute = DictionaryAttribute::new(&context, &[(id, val)]);

        assert_eq!(attribute.element_by_name("baz"), Some(val));
        assert_eq!(attribute.element_by_name("missing"), None);
    }

    #[test]
    fn element_by_name_on_empty() {
        let context = create_test_context();
        let attribute = DictionaryAttribute::new(&context, &[]);

        assert_eq!(attribute.element_by_name("anything"), None);
    }

    #[test]
    fn multiple_elements() {
        let context = create_test_context();
        let first = integer_attribute(&context, 1);
        let second = integer_attribute(&context, 2);
        let attribute = DictionaryAttribute::new(
            &context,
            &[
                (Identifier::new(&context, "a"), first),
                (Identifier::new(&context, "b"), second),
            ],
        );

        assert_eq!(attribute.len(), 2);
        assert_eq!(attribute.element_by_name("a"), Some(first));
        assert_eq!(attribute.element_by_name("b"), Some(second));
        assert!(attribute.element(2).is_err());

        // Compare names order-independently: the stored order is decided by MLIR, not by
        // the order the pairs were passed in.
        let mut names = (0..attribute.len())
            .map(|index| {
                let (name, _) = attribute.element(index).unwrap();
                name.as_string_ref().as_str().unwrap().to_owned()
            })
            .collect::<Vec<_>>();
        names.sort();

        assert_eq!(names, ["a", "b"]);
    }
}