1#![doc = include_str!("../README.md")]
2
3extern crate self as melior;
4
5#[macro_use]
6mod r#macro;
7mod context;
8pub mod diagnostic;
9pub mod dialect;
10mod error;
11mod execution_engine;
12mod greedy_rewrite_driver;
13#[cfg(feature = "helpers")]
14pub mod helpers;
15pub mod ir;
16mod ir_rewriter;
17mod logical_result;
18pub mod pass;
19mod rewrite_pattern;
20mod string_ref;
21mod thread_pool;
22
23#[cfg(test)]
24mod test;
25pub mod utility;
26
27pub use self::{
28 context::{Context, ContextRef},
29 error::Error,
30 execution_engine::ExecutionEngine,
31 greedy_rewrite_driver::{
32 GreedyRewriteDriverConfig, GreedyRewriteStrictness, GreedySimplifyRegionLevel,
33 apply_patterns_and_fold_greedily, walk_and_apply_patterns,
34 },
35 ir_rewriter::{IrRewriter, RewriterBase},
36 rewrite_pattern::{
37 FrozenRewritePatternSet, PatternRewriter, RewritePattern, RewritePatternSet,
38 create_op_rewrite_pattern,
39 },
40 string_ref::StringRef,
41 thread_pool::ThreadPool,
42};
43
44pub use melior_macro::dialect;
45
#[cfg(test)]
mod tests {
    use crate::{
        context::Context,
        dialect::{self, arith, func, scf},
        ir::{
            Block, BlockLike, Location, Module, Region, RegionLike, Type, Value,
            attribute::{IntegerAttribute, StringAttribute, TypeAttribute},
            operation::{OperationBuilder, OperationLike},
            r#type::{FunctionType, IntegerType},
        },
        test::load_all_dialects,
    };

    // NOTE: `insta::assert_snapshot!` derives snapshot file names from the test
    // function names, so renaming these tests would orphan their snapshots.

    /// An empty module created in a bare context verifies and prints.
    #[test]
    fn build_module() {
        let context = Context::new();
        let module = Module::new(Location::unknown(&context));

        assert!(module.as_operation().verify());
        insta::assert_snapshot!(module.as_operation());
    }

    /// Appending a freshly created dialect registry to the context still
    /// yields a module that verifies and prints.
    #[test]
    fn build_module_with_dialect() {
        let registry = dialect::DialectRegistry::new();
        let context = Context::new();
        context.append_dialect_registry(&registry);
        let module = Module::new(Location::unknown(&context));

        assert!(module.as_operation().verify());
        insta::assert_snapshot!(module.as_operation());
    }

    /// Builds `func @add(i64, i64) -> i64` whose body is a single `arith.addi`
    /// of its two arguments followed by `func.return` of the sum.
    #[test]
    fn build_add() {
        let context = Context::new();
        load_all_dialects(&context);

        let location = Location::unknown(&context);
        let module = Module::new(location);

        let integer_type = IntegerType::new(&context, 64).into();

        let function = {
            let block = Block::new(&[(integer_type, location), (integer_type, location)]);

            let sum = block.append_operation(arith::addi(
                block.argument(0).unwrap().into(),
                block.argument(1).unwrap().into(),
                location,
            ));

            block.append_operation(func::r#return(&[sum.result(0).unwrap().into()], location));

            let region = Region::new();
            region.append_block(block);

            func::func(
                &context,
                StringAttribute::new(&context, "add"),
                TypeAttribute::new(
                    FunctionType::new(&context, &[integer_type, integer_type], &[integer_type])
                        .into(),
                ),
                region,
                &[],
                location,
            )
        };

        module.body().append_operation(function);

        assert!(module.as_operation().verify());
        insta::assert_snapshot!(module.as_operation());
    }

    /// Builds `func @sum(memref<?xf32>, memref<?xf32>)` that loops `i` from 0
    /// to `dim(arg0, 0)` with step 1 via `scf.for`, computing
    /// `arg0[i] = arg0[i] + arg1[i]` with generic `memref.load`/`memref.store`
    /// operations.
    #[test]
    fn build_sum() {
        let context = Context::new();
        load_all_dialects(&context);

        let location = Location::unknown(&context);
        let module = Module::new(location);

        let memref_type = Type::parse(&context, "memref<?xf32>").unwrap();
        // Built-in types; constructed directly instead of parsing their textual
        // form, so no `unwrap` is needed.
        let index_type = Type::index(&context);
        let f32_type = Type::float32(&context);

        let function = {
            let function_block = Block::new(&[(memref_type, location), (memref_type, location)]);

            let zero = function_block.append_operation(arith::constant(
                &context,
                IntegerAttribute::new(index_type, 0).into(),
                location,
            ));

            // Runtime length of the first memref along dimension 0 (the loop bound).
            let dim = function_block.append_operation(
                OperationBuilder::new("memref.dim", location)
                    .add_operands(&[
                        function_block.argument(0).unwrap().into(),
                        zero.result(0).unwrap().into(),
                    ])
                    .add_results(&[index_type])
                    .build()
                    .unwrap(),
            );

            // The loop body block takes a single index argument, which is bound to
            // the loop's induction variable by `scf::r#for` below.
            let loop_block = Block::new(&[(index_type, location)]);

            let one = function_block.append_operation(arith::constant(
                &context,
                IntegerAttribute::new(index_type, 1).into(),
                location,
            ));

            {
                let lhs = loop_block.append_operation(
                    OperationBuilder::new("memref.load", location)
                        .add_operands(&[
                            function_block.argument(0).unwrap().into(),
                            loop_block.argument(0).unwrap().into(),
                        ])
                        .add_results(&[f32_type])
                        .build()
                        .unwrap(),
                );

                let rhs = loop_block.append_operation(
                    OperationBuilder::new("memref.load", location)
                        .add_operands(&[
                            function_block.argument(1).unwrap().into(),
                            loop_block.argument(0).unwrap().into(),
                        ])
                        .add_results(&[f32_type])
                        .build()
                        .unwrap(),
                );

                let add = loop_block.append_operation(arith::addf(
                    lhs.result(0).unwrap().into(),
                    rhs.result(0).unwrap().into(),
                    location,
                ));

                // Operand order: value to store, destination memref, index.
                loop_block.append_operation(
                    OperationBuilder::new("memref.store", location)
                        .add_operands(&[
                            add.result(0).unwrap().into(),
                            function_block.argument(0).unwrap().into(),
                            loop_block.argument(0).unwrap().into(),
                        ])
                        .build()
                        .unwrap(),
                );

                loop_block.append_operation(scf::r#yield(&[], location));
            }

            function_block.append_operation(scf::r#for(
                zero.result(0).unwrap().into(),
                dim.result(0).unwrap().into(),
                one.result(0).unwrap().into(),
                {
                    let loop_region = Region::new();
                    loop_region.append_block(loop_block);
                    loop_region
                },
                location,
            ));

            function_block.append_operation(func::r#return(&[], location));

            let function_region = Region::new();
            function_region.append_block(function_block);

            func::func(
                &context,
                StringAttribute::new(&context, "sum"),
                TypeAttribute::new(
                    FunctionType::new(&context, &[memref_type, memref_type], &[]).into(),
                ),
                function_region,
                &[],
                location,
            )
        };

        module.body().append_operation(function);

        assert!(module.as_operation().verify());
        insta::assert_snapshot!(module.as_operation());
    }

    /// Checks that a `Value` produced by a helper that appends into a borrowed
    /// block can be returned from that helper and fed to `func.return`.
    #[test]
    fn return_value_from_function() {
        /// Appends `arith.addi lhs, rhs` to `block` and returns its result.
        ///
        /// The returned value borrows `block` (lifetime `'a`), which is what
        /// this test exercises.
        fn compile_add<'c, 'a>(
            context: &'c Context,
            block: &'a Block<'c>,
            lhs: Value<'c, '_>,
            rhs: Value<'c, '_>,
        ) -> Value<'c, 'a> {
            block
                .append_operation(arith::addi(lhs, rhs, Location::unknown(context)))
                .result(0)
                .unwrap()
                .into()
        }

        let context = Context::new();
        load_all_dialects(&context);

        let location = Location::unknown(&context);
        let module = Module::new(location);

        let integer_type = IntegerType::new(&context, 64).into();

        module.body().append_operation(func::func(
            &context,
            StringAttribute::new(&context, "add"),
            TypeAttribute::new(
                FunctionType::new(&context, &[integer_type, integer_type], &[integer_type]).into(),
            ),
            {
                let block = Block::new(&[(integer_type, location), (integer_type, location)]);

                block.append_operation(func::r#return(
                    &[compile_add(
                        &context,
                        &block,
                        block.argument(0).unwrap().into(),
                        block.argument(1).unwrap().into(),
                    )],
                    location,
                ));

                let region = Region::new();
                region.append_block(block);
                region
            },
            &[],
            location,
        ));

        assert!(module.as_operation().verify());
        insta::assert_snapshot!(module.as_operation());
    }
}