melior/ir/type/
function.rs1use super::TypeLike;
2use crate::{Context, Error, ir::Type};
3use mlir_sys::{
4 MlirType, mlirFunctionTypeGet, mlirFunctionTypeGetInput, mlirFunctionTypeGetNumInputs,
5 mlirFunctionTypeGetNumResults, mlirFunctionTypeGetResult,
6};
7
8#[derive(Clone, Copy, Debug, Hash)]
10pub struct FunctionType<'c> {
11 r#type: Type<'c>,
12}
13
14impl<'c> FunctionType<'c> {
15 pub fn new(context: &'c Context, inputs: &[Type<'c>], results: &[Type<'c>]) -> Self {
17 Self {
18 r#type: unsafe {
19 Type::from_raw(mlirFunctionTypeGet(
20 context.to_raw(),
21 inputs.len() as isize,
22 inputs as *const _ as *const _,
23 results.len() as isize,
24 results as *const _ as *const _,
25 ))
26 },
27 }
28 }
29
30 pub fn input(&self, index: usize) -> Result<Type<'c>, Error> {
32 if index < self.input_count() {
33 unsafe {
34 Ok(Type::from_raw(mlirFunctionTypeGetInput(
35 self.r#type.to_raw(),
36 index as isize,
37 )))
38 }
39 } else {
40 Err(Error::PositionOutOfBounds {
41 name: "function input",
42 value: self.to_string(),
43 index,
44 })
45 }
46 }
47
48 pub fn result(&self, index: usize) -> Result<Type<'c>, Error> {
50 if index < self.result_count() {
51 unsafe {
52 Ok(Type::from_raw(mlirFunctionTypeGetResult(
53 self.r#type.to_raw(),
54 index as isize,
55 )))
56 }
57 } else {
58 Err(Error::PositionOutOfBounds {
59 name: "function result",
60 value: self.to_string(),
61 index,
62 })
63 }
64 }
65
66 pub fn input_count(&self) -> usize {
68 unsafe { mlirFunctionTypeGetNumInputs(self.r#type.to_raw()) as usize }
69 }
70
71 pub fn result_count(&self) -> usize {
73 unsafe { mlirFunctionTypeGetNumResults(self.r#type.to_raw()) as usize }
74 }
75}
76
// Hooks `FunctionType` into the shared type machinery via the `type_traits!` macro
// (defined elsewhere). It passes `is_function` (presumably the `Type` predicate used
// to check conversions) and "function" as the kind's display name.
// NOTE(review): the generated impls are not visible here. The `to_string()` calls and
// the `Type::from(FunctionType)` conversions used in this file presumably come from
// this macro; confirm against its definition.
type_traits!(FunctionType, is_function, "function");
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Context;

    #[test]
    fn new() {
        let context = Context::new();
        let index_type = Type::index(&context);

        assert_eq!(
            Type::from(FunctionType::new(
                &context,
                &[index_type, index_type],
                &[index_type]
            )),
            Type::parse(&context, "(index, index) -> index").unwrap()
        );
    }

    #[test]
    fn new_empty() {
        let context = Context::new();

        assert_eq!(
            Type::from(FunctionType::new(&context, &[], &[])),
            Type::parse(&context, "() -> ()").unwrap()
        );
    }

    #[test]
    fn multiple_results() {
        let context = Context::new();
        let index_type = Type::index(&context);

        assert_eq!(
            Type::from(FunctionType::new(&context, &[], &[index_type, index_type])),
            Type::parse(&context, "() -> (index, index)").unwrap()
        );
    }

    #[test]
    fn input() {
        let context = Context::new();
        let index_type = Type::index(&context);

        assert_eq!(
            FunctionType::new(&context, &[index_type], &[]).input(0),
            Ok(index_type)
        );
    }

    #[test]
    fn input_order() {
        let context = Context::new();
        let index_type = Type::index(&context);
        let i32_type = Type::parse(&context, "i32").unwrap();
        let function = FunctionType::new(&context, &[index_type, i32_type], &[]);

        assert_eq!(function.input(0), Ok(index_type));
        assert_eq!(function.input(1), Ok(i32_type));
    }

    #[test]
    fn input_error() {
        let context = Context::new();
        let index_type = Type::index(&context);
        let function = FunctionType::new(&context, &[index_type], &[]);

        assert_eq!(
            function.input(42),
            Err(Error::PositionOutOfBounds {
                name: "function input",
                value: function.to_string(),
                index: 42
            })
        );
    }

    #[test]
    fn input_error_at_boundary() {
        let context = Context::new();
        let index_type = Type::index(&context);
        let function = FunctionType::new(&context, &[index_type], &[]);

        // `index == input_count()` is the first out-of-bounds position.
        assert_eq!(
            function.input(1),
            Err(Error::PositionOutOfBounds {
                name: "function input",
                value: function.to_string(),
                index: 1
            })
        );
    }

    #[test]
    fn result() {
        let context = Context::new();
        let index_type = Type::index(&context);

        assert_eq!(
            FunctionType::new(&context, &[], &[index_type]).result(0),
            Ok(index_type)
        );
    }

    #[test]
    fn result_order() {
        let context = Context::new();
        let index_type = Type::index(&context);
        let i32_type = Type::parse(&context, "i32").unwrap();
        let function = FunctionType::new(&context, &[], &[index_type, i32_type]);

        assert_eq!(function.result(0), Ok(index_type));
        assert_eq!(function.result(1), Ok(i32_type));
    }

    #[test]
    fn result_error() {
        let context = Context::new();
        let index_type = Type::index(&context);
        let function = FunctionType::new(&context, &[], &[index_type]);

        assert_eq!(
            function.result(42),
            Err(Error::PositionOutOfBounds {
                name: "function result",
                value: function.to_string(),
                index: 42
            })
        );
    }

    #[test]
    fn result_error_at_boundary() {
        let context = Context::new();
        let index_type = Type::index(&context);
        let function = FunctionType::new(&context, &[], &[index_type]);

        // `index == result_count()` is the first out-of-bounds position.
        assert_eq!(
            function.result(1),
            Err(Error::PositionOutOfBounds {
                name: "function result",
                value: function.to_string(),
                index: 1
            })
        );
    }

    #[test]
    fn input_count() {
        let context = Context::new();
        let index_type = Type::index(&context);

        assert_eq!(
            FunctionType::new(&context, &[index_type], &[]).input_count(),
            1
        );
    }

    #[test]
    fn result_count() {
        let context = Context::new();
        let index_type = Type::index(&context);

        assert_eq!(
            FunctionType::new(&context, &[], &[index_type]).result_count(),
            1
        );
    }

    #[test]
    fn empty_counts() {
        let context = Context::new();
        let function = FunctionType::new(&context, &[], &[]);

        assert_eq!(function.input_count(), 0);
        assert_eq!(function.result_count(), 0);
    }
}