melior/ir/type/function.rs

1use super::TypeLike;
2use crate::{Context, Error, ir::Type};
3use mlir_sys::{
4    MlirType, mlirFunctionTypeGet, mlirFunctionTypeGetInput, mlirFunctionTypeGetNumInputs,
5    mlirFunctionTypeGetNumResults, mlirFunctionTypeGetResult,
6};
7
/// A function type.
///
/// A typed view over a [`Type`] holding an MLIR function type such as
/// `(index, index) -> index`, with accessors for its input and result types.
/// `'c` is the lifetime of the [`Context`] the type was created in.
#[derive(Clone, Copy, Debug, Hash)]
pub struct FunctionType<'c> {
    // The wrapped generic type; `FunctionType::new` always builds a function type here.
    r#type: Type<'c>,
}
13
14impl<'c> FunctionType<'c> {
15    /// Creates a function type.
16    pub fn new(context: &'c Context, inputs: &[Type<'c>], results: &[Type<'c>]) -> Self {
17        Self {
18            r#type: unsafe {
19                Type::from_raw(mlirFunctionTypeGet(
20                    context.to_raw(),
21                    inputs.len() as isize,
22                    inputs as *const _ as *const _,
23                    results.len() as isize,
24                    results as *const _ as *const _,
25                ))
26            },
27        }
28    }
29
30    /// Returns an input at a position.
31    pub fn input(&self, index: usize) -> Result<Type<'c>, Error> {
32        if index < self.input_count() {
33            unsafe {
34                Ok(Type::from_raw(mlirFunctionTypeGetInput(
35                    self.r#type.to_raw(),
36                    index as isize,
37                )))
38            }
39        } else {
40            Err(Error::PositionOutOfBounds {
41                name: "function input",
42                value: self.to_string(),
43                index,
44            })
45        }
46    }
47
48    /// Returns a result at a position.
49    pub fn result(&self, index: usize) -> Result<Type<'c>, Error> {
50        if index < self.result_count() {
51            unsafe {
52                Ok(Type::from_raw(mlirFunctionTypeGetResult(
53                    self.r#type.to_raw(),
54                    index as isize,
55                )))
56            }
57        } else {
58            Err(Error::PositionOutOfBounds {
59                name: "function result",
60                value: self.to_string(),
61                index,
62            })
63        }
64    }
65
66    /// Returns a number of inputs.
67    pub fn input_count(&self) -> usize {
68        unsafe { mlirFunctionTypeGetNumInputs(self.r#type.to_raw()) as usize }
69    }
70
71    /// Returns a number of results.
72    pub fn result_count(&self) -> usize {
73        unsafe { mlirFunctionTypeGetNumResults(self.r#type.to_raw()) as usize }
74    }
75}
76
// Shared type boilerplate generated by `type_traits!`. Judging from its uses in this
// file, it provides at least a `to_string` implementation (presumably via `Display`,
// used in the out-of-bounds errors) and `From<FunctionType> for Type` (used by the
// tests). `is_function` is presumably the predicate checked when converting a generic
// `Type` into a `FunctionType` — confirm against the macro definition.
type_traits!(FunctionType, is_function, "function");
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Context;

    #[test]
    fn new() {
        let context = Context::new();
        let index = Type::index(&context);

        assert_eq!(
            Type::from(FunctionType::new(&context, &[index, index], &[index])),
            Type::parse(&context, "(index, index) -> index").unwrap()
        );
    }

    #[test]
    fn new_empty() {
        let context = Context::new();
        let function = FunctionType::new(&context, &[], &[]);

        assert_eq!(
            Type::from(function),
            Type::parse(&context, "() -> ()").unwrap()
        );
        assert_eq!(function.input_count(), 0);
        assert_eq!(function.result_count(), 0);
    }

    #[test]
    fn multiple_results() {
        let context = Context::new();
        let index = Type::index(&context);

        assert_eq!(
            Type::from(FunctionType::new(&context, &[], &[index, index])),
            Type::parse(&context, "() -> (index, index)").unwrap()
        );
    }

    #[test]
    fn input() {
        let context = Context::new();
        let index = Type::index(&context);

        assert_eq!(
            FunctionType::new(&context, &[index], &[]).input(0),
            Ok(index)
        );
    }

    #[test]
    fn input_and_result_order() {
        // Distinct types catch any mix-up of positions or of inputs vs. results.
        let context = Context::new();
        let index = Type::index(&context);
        let float = Type::parse(&context, "f32").unwrap();
        let function = FunctionType::new(&context, &[index, float], &[float, index]);

        assert_eq!(function.input(0), Ok(index));
        assert_eq!(function.input(1), Ok(float));
        assert_eq!(function.result(0), Ok(float));
        assert_eq!(function.result(1), Ok(index));
    }

    #[test]
    fn input_error() {
        let context = Context::new();
        let index = Type::index(&context);
        let function = FunctionType::new(&context, &[index], &[]);

        assert_eq!(
            function.input(42),
            Err(Error::PositionOutOfBounds {
                name: "function input",
                value: function.to_string(),
                index: 42
            })
        );
    }

    #[test]
    fn input_error_at_count() {
        // `index == input_count()` is the first invalid position.
        let context = Context::new();
        let index = Type::index(&context);
        let function = FunctionType::new(&context, &[index], &[]);

        assert_eq!(
            function.input(1),
            Err(Error::PositionOutOfBounds {
                name: "function input",
                value: function.to_string(),
                index: 1
            })
        );
    }

    #[test]
    fn result() {
        let context = Context::new();
        let index = Type::index(&context);

        assert_eq!(
            FunctionType::new(&context, &[], &[index]).result(0),
            Ok(index)
        );
    }

    #[test]
    fn result_error() {
        let context = Context::new();
        let index = Type::index(&context);
        let function = FunctionType::new(&context, &[], &[index]);

        assert_eq!(
            function.result(42),
            Err(Error::PositionOutOfBounds {
                name: "function result",
                value: function.to_string(),
                index: 42
            })
        );
    }

    #[test]
    fn result_error_at_count() {
        // `index == result_count()` is the first invalid position.
        let context = Context::new();
        let index = Type::index(&context);
        let function = FunctionType::new(&context, &[], &[index]);

        assert_eq!(
            function.result(1),
            Err(Error::PositionOutOfBounds {
                name: "function result",
                value: function.to_string(),
                index: 1
            })
        );
    }

    #[test]
    fn input_count() {
        let context = Context::new();
        let index = Type::index(&context);

        assert_eq!(
            FunctionType::new(&context, &[index], &[]).input_count(),
            1
        );
    }

    #[test]
    fn result_count() {
        let context = Context::new();
        let index = Type::index(&context);

        assert_eq!(
            FunctionType::new(&context, &[], &[index]).result_count(),
            1
        );
    }
}