melior/
lib.rs

1#![doc = include_str!("../README.md")]
2
3extern crate self as melior;
4
5#[macro_use]
6mod r#macro;
7mod context;
8pub mod diagnostic;
9pub mod dialect;
10mod error;
11mod execution_engine;
12mod greedy_rewrite_driver;
13#[cfg(feature = "helpers")]
14pub mod helpers;
15pub mod ir;
16mod ir_rewriter;
17mod logical_result;
18pub mod pass;
19mod rewrite_pattern;
20mod string_ref;
21mod thread_pool;
22
23#[cfg(test)]
24mod test;
25pub mod utility;
26
27pub use self::{
28    context::{Context, ContextRef},
29    error::Error,
30    execution_engine::ExecutionEngine,
31    greedy_rewrite_driver::{
32        GreedyRewriteDriverConfig, GreedyRewriteStrictness, GreedySimplifyRegionLevel,
33        apply_patterns_and_fold_greedily, walk_and_apply_patterns,
34    },
35    ir_rewriter::{IrRewriter, RewriterBase},
36    rewrite_pattern::{
37        FrozenRewritePatternSet, PatternRewriter, RewritePattern, RewritePatternSet,
38        create_op_rewrite_pattern,
39    },
40    string_ref::StringRef,
41    thread_pool::ThreadPool,
42};
43
44pub use melior_macro::dialect;
45
#[cfg(test)]
mod tests {
    //! End-to-end smoke tests that build small MLIR modules through the public
    //! builder API, run the verifier on them, and pin the printed IR with `insta`
    //! snapshots.
    //!
    //! Keep each `insta::assert_snapshot!` call inside its `#[test]` fn. Without an
    //! explicit name, insta derives the snapshot name from the enclosing function.

    use crate::{
        context::Context,
        dialect::{self, arith, func, scf},
        ir::{
            Block, BlockLike, Location, Module, Region, RegionLike, Type, Value,
            attribute::{IntegerAttribute, StringAttribute, TypeAttribute},
            operation::{OperationBuilder, OperationLike},
            r#type::{FunctionType, IntegerType},
        },
        test::load_all_dialects,
    };

    /// An empty module built in a fresh context verifies and matches its snapshot.
    #[test]
    fn build_module() {
        let context = Context::new();
        let module = Module::new(Location::unknown(&context));

        assert!(module.as_operation().verify());
        insta::assert_snapshot!(module.as_operation());
    }

    /// An empty module can be built and verified in a context that has a dialect
    /// registry appended.
    ///
    /// The registry is left empty. This test only exercises
    /// `Context::append_dialect_registry`.
    #[test]
    fn build_module_with_dialect() {
        let registry = dialect::DialectRegistry::new();
        let context = Context::new();
        context.append_dialect_registry(&registry);
        let module = Module::new(Location::unknown(&context));

        assert!(module.as_operation().verify());
        insta::assert_snapshot!(module.as_operation());
    }

    /// Builds `func.func @add(i64, i64) -> i64`. Its body is a single `arith.addi`
    /// of the two block arguments, followed by `func.return` of the sum.
    #[test]
    fn build_add() {
        let context = Context::new();
        load_all_dialects(&context);

        let location = Location::unknown(&context);
        let module = Module::new(location);

        let integer_type = IntegerType::new(&context, 64).into();

        let function = {
            let block = Block::new(&[(integer_type, location), (integer_type, location)]);

            let sum = block.append_operation(arith::addi(
                block.argument(0).unwrap().into(),
                block.argument(1).unwrap().into(),
                location,
            ));

            block.append_operation(func::r#return(&[sum.result(0).unwrap().into()], location));

            let region = Region::new();
            region.append_block(block);

            func::func(
                &context,
                StringAttribute::new(&context, "add"),
                TypeAttribute::new(
                    FunctionType::new(&context, &[integer_type, integer_type], &[integer_type])
                        .into(),
                ),
                region,
                &[],
                location,
            )
        };

        module.body().append_operation(function);

        assert!(module.as_operation().verify());
        insta::assert_snapshot!(module.as_operation());
    }

    /// Builds `func.func @sum(memref<?xf32>, memref<?xf32>)`.
    ///
    /// An `scf.for` runs from 0 to the size of dimension 0 of the first argument.
    /// Each iteration loads `lhs[i]` and `rhs[i]`, adds them with `arith.addf`, and
    /// stores the result back into the first argument at index `i`.
    #[test]
    fn build_sum() {
        let context = Context::new();
        load_all_dialects(&context);

        let location = Location::unknown(&context);
        let module = Module::new(location);

        let memref_type = Type::parse(&context, "memref<?xf32>").unwrap();

        let function = {
            let function_block = Block::new(&[(memref_type, location), (memref_type, location)]);
            // Build the builtin index type directly instead of parsing "index". The
            // same type is shared by the constants, `memref.dim` and the loop
            // induction variable.
            let index_type = Type::index(&context);

            let zero = function_block.append_operation(arith::constant(
                &context,
                IntegerAttribute::new(index_type, 0).into(),
                location,
            ));

            // Loop upper bound: the runtime size of dimension `zero` (= 0) of the
            // first argument.
            let dim = function_block.append_operation(
                OperationBuilder::new("memref.dim", location)
                    .add_operands(&[
                        function_block.argument(0).unwrap().into(),
                        zero.result(0).unwrap().into(),
                    ])
                    .add_results(&[index_type])
                    .build()
                    .unwrap(),
            );

            // Loop body. Its single block argument is the induction variable.
            let loop_block = Block::new(&[(index_type, location)]);

            let one = function_block.append_operation(arith::constant(
                &context,
                IntegerAttribute::new(index_type, 1).into(),
                location,
            ));

            {
                let f32_type = Type::float32(&context);

                let lhs = loop_block.append_operation(
                    OperationBuilder::new("memref.load", location)
                        .add_operands(&[
                            function_block.argument(0).unwrap().into(),
                            loop_block.argument(0).unwrap().into(),
                        ])
                        .add_results(&[f32_type])
                        .build()
                        .unwrap(),
                );

                let rhs = loop_block.append_operation(
                    OperationBuilder::new("memref.load", location)
                        .add_operands(&[
                            function_block.argument(1).unwrap().into(),
                            loop_block.argument(0).unwrap().into(),
                        ])
                        .add_results(&[f32_type])
                        .build()
                        .unwrap(),
                );

                let add = loop_block.append_operation(arith::addf(
                    lhs.result(0).unwrap().into(),
                    rhs.result(0).unwrap().into(),
                    location,
                ));

                // Operand order for `memref.store`: value, memref, index.
                loop_block.append_operation(
                    OperationBuilder::new("memref.store", location)
                        .add_operands(&[
                            add.result(0).unwrap().into(),
                            function_block.argument(0).unwrap().into(),
                            loop_block.argument(0).unwrap().into(),
                        ])
                        .build()
                        .unwrap(),
                );

                loop_block.append_operation(scf::r#yield(&[], location));
            }

            function_block.append_operation(scf::r#for(
                zero.result(0).unwrap().into(),
                dim.result(0).unwrap().into(),
                one.result(0).unwrap().into(),
                {
                    let loop_region = Region::new();
                    loop_region.append_block(loop_block);
                    loop_region
                },
                location,
            ));

            function_block.append_operation(func::r#return(&[], location));

            let function_region = Region::new();
            function_region.append_block(function_block);

            func::func(
                &context,
                StringAttribute::new(&context, "sum"),
                TypeAttribute::new(
                    FunctionType::new(&context, &[memref_type, memref_type], &[]).into(),
                ),
                function_region,
                &[],
                location,
            )
        };

        module.body().append_operation(function);

        assert!(module.as_operation().verify());
        insta::assert_snapshot!(module.as_operation());
    }

    /// Same function as `build_add`, but the `arith.addi` is emitted by a helper
    /// that returns the result `Value`. The returned value borrows from the block
    /// the operation was appended to.
    #[test]
    fn return_value_from_function() {
        /// Appends `arith.addi lhs, rhs` to `block` and returns its first result.
        /// The result's `'a` lifetime is tied to the borrow of `block`.
        fn compile_add<'c, 'a>(
            context: &'c Context,
            block: &'a Block<'c>,
            lhs: Value<'c, '_>,
            rhs: Value<'c, '_>,
        ) -> Value<'c, 'a> {
            block
                .append_operation(arith::addi(lhs, rhs, Location::unknown(context)))
                .result(0)
                .unwrap()
                .into()
        }

        let context = Context::new();
        load_all_dialects(&context);

        let location = Location::unknown(&context);
        let module = Module::new(location);

        let integer_type = IntegerType::new(&context, 64).into();

        module.body().append_operation(func::func(
            &context,
            StringAttribute::new(&context, "add"),
            TypeAttribute::new(
                FunctionType::new(&context, &[integer_type, integer_type], &[integer_type]).into(),
            ),
            {
                let block = Block::new(&[(integer_type, location), (integer_type, location)]);

                block.append_operation(func::r#return(
                    &[compile_add(
                        &context,
                        &block,
                        block.argument(0).unwrap().into(),
                        block.argument(1).unwrap().into(),
                    )],
                    location,
                ));

                let region = Region::new();
                region.append_block(block);
                region
            },
            &[],
            location,
        ));

        assert!(module.as_operation().verify());
        insta::assert_snapshot!(module.as_operation());
    }
}