1pub mod arith;
4pub mod cf;
5pub mod func;
6mod handle;
7pub mod index;
8pub mod llvm;
9pub mod memref;
10mod registry;
11pub mod scf;
12
13pub use self::{handle::DialectHandle, registry::DialectRegistry};
14use crate::{
15 context::{Context, ContextRef},
16 string_ref::StringRef,
17};
18use mlir_sys::{MlirDialect, mlirDialectEqual, mlirDialectGetContext, mlirDialectGetNamespace};
19use std::{marker::PhantomData, str::Utf8Error};
20
21#[cfg(feature = "ods-dialects")]
22pub mod ods;
23
24#[derive(Clone, Copy, Debug)]
26pub struct Dialect<'c> {
27 raw: MlirDialect,
28 _context: PhantomData<&'c Context>,
29}
30
31impl<'c> Dialect<'c> {
32 pub fn context(&self) -> ContextRef<'c> {
34 unsafe { ContextRef::from_raw(mlirDialectGetContext(self.raw)) }
35 }
36
37 pub fn namespace(&self) -> Result<&str, Utf8Error> {
39 unsafe { StringRef::from_raw(mlirDialectGetNamespace(self.raw)) }.as_str()
40 }
41
42 pub unsafe fn from_raw(dialect: MlirDialect) -> Self {
48 Self {
49 raw: dialect,
50 _context: Default::default(),
51 }
52 }
53}
54
55impl PartialEq for Dialect<'_> {
56 fn eq(&self, other: &Self) -> bool {
57 unsafe { mlirDialectEqual(self.raw, other.raw) }
58 }
59}
60
61impl Eq for Dialect<'_> {}
62
#[cfg(test)]
mod tests {
    use super::*;

    // Loading the LLVM dialect yields a dialect whose namespace is "llvm".
    #[test]
    fn namespace() {
        let context = Context::new();
        let dialect = DialectHandle::llvm().load_dialect(&context);

        assert_eq!(dialect.namespace(), Ok("llvm"));
    }

    // Loading the same dialect twice into one context yields equal values.
    #[test]
    fn equal() {
        let context = Context::new();
        let first = DialectHandle::func().load_dialect(&context);
        let second = DialectHandle::func().load_dialect(&context);

        assert_eq!(first, second);
    }

    // Different dialects loaded into one context compare as unequal.
    #[test]
    fn not_equal() {
        let context = Context::new();
        let func = DialectHandle::func().load_dialect(&context);
        let llvm = DialectHandle::llvm().load_dialect(&context);

        assert_ne!(func, llvm);
    }
}