Skip to content

Commit e10ee08

Browse files
mgr0dzickitudny
andauthored
Add functions to parser (#57)
* Add functions * clippy * added expr tests in parser * Update error * added identity to parser tests * testing existance of std functions * fmt --------- Co-authored-by: Aleksander Tudruj <[email protected]>
1 parent 0975c2d commit e10ee08

File tree

3 files changed

+222
-35
lines changed

3 files changed

+222
-35
lines changed

src/environment.rs

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::collections::btree_map::IterMut;
22
use std::collections::BTreeMap;
33

4-
use anyhow::bail;
4+
use anyhow::{bail, Context};
55

66
use crate::locale::Locale;
77
use crate::traits::{GuiDisplayable, LaTeXable};
@@ -23,6 +23,10 @@ impl Identifier {
2323
}
2424
}
2525

26+
fn new_unsafe(id: String) -> Self {
27+
Self { id }
28+
}
29+
2630
pub fn result() -> Self {
2731
Self {
2832
id: Self::RESULT.to_string(),
@@ -64,6 +68,20 @@ impl<T: MatrixNumber> Type<T> {
6468
pub fn from_matrix_result(opt: anyhow::Result<Matrix<T>>) -> anyhow::Result<Self> {
6569
Ok(Self::Matrix(opt?))
6670
}
71+
72+
pub fn into_scalar(self) -> anyhow::Result<T> {
73+
match self {
74+
Type::Scalar(s) => Ok(s),
75+
Type::Matrix(_) => bail!("Expected scalar, got matrix."),
76+
}
77+
}
78+
79+
pub fn into_matrix(self) -> anyhow::Result<Matrix<T>> {
80+
match self {
81+
Type::Matrix(m) => Ok(m),
82+
Type::Scalar(_) => bail!("Expected matrix, got scalar."),
83+
}
84+
}
6785
}
6886

6987
impl<T: MatrixNumber> ToString for Type<T> {
@@ -106,25 +124,58 @@ impl<T: MatrixNumber> LaTeXable for Type<T> {
106124
}
107125
}
108126

127+
pub type Callable<T> = dyn Fn(Type<T>) -> anyhow::Result<Type<T>>;
128+
129+
fn builtin_functions<T: MatrixNumber>() -> BTreeMap<Identifier, Box<Callable<T>>> {
130+
BTreeMap::from([
131+
(
132+
Identifier::new_unsafe("transpose".to_string()),
133+
Box::new(|t: Type<T>| Ok(Type::Matrix(t.into_matrix()?.transpose())))
134+
as Box<Callable<T>>,
135+
),
136+
(
137+
Identifier::new_unsafe("identity".to_string()),
138+
Box::new(|t: Type<T>| {
139+
Ok(Type::Matrix(Matrix::identity(
140+
t.into_scalar()?
141+
.to_usize()
142+
.context("Invalid identity argument")?,
143+
)))
144+
}) as Box<Callable<T>>,
145+
),
146+
(
147+
Identifier::new_unsafe("inverse".to_string()),
148+
Box::new(|t: Type<T>| Ok(Type::Matrix(t.into_matrix()?.inverse()?.result)))
149+
as Box<Callable<T>>,
150+
),
151+
])
152+
}
153+
109154
pub struct Environment<T: MatrixNumber> {
110155
env: BTreeMap<Identifier, Type<T>>,
156+
fun: BTreeMap<Identifier, Box<Callable<T>>>,
111157
}
112158

113159
impl<T: MatrixNumber> Environment<T> {
114160
pub fn new() -> Self {
115161
Self {
116162
env: BTreeMap::new(),
163+
fun: builtin_functions(),
117164
}
118165
}
119166

120167
pub fn insert(&mut self, id: Identifier, value: Type<T>) {
121168
self.env.insert(id, value);
122169
}
123170

124-
pub fn get(&self, id: &Identifier) -> Option<&Type<T>> {
171+
pub fn get_value(&self, id: &Identifier) -> Option<&Type<T>> {
125172
self.env.get(id)
126173
}
127174

175+
pub fn get_function(&self, id: &Identifier) -> Option<&Callable<T>> {
176+
self.fun.get(id).map(|f| f.as_ref())
177+
}
178+
128179
pub fn iter_mut(&mut self) -> IterMut<'_, Identifier, Type<T>> {
129180
self.env.iter_mut()
130181
}
@@ -155,4 +206,19 @@ mod tests {
155206
assert!(matches!(Identifier::new("32".to_string()), Err(_)));
156207
assert!(matches!(Identifier::new("".to_string()), Err(_)));
157208
}
209+
210+
#[test]
211+
fn test_env_contains_std_fun() {
212+
let env = Environment::<i64>::new();
213+
214+
assert!(env
215+
.get_function(&Identifier::new_unsafe("transpose".to_string()))
216+
.is_some());
217+
assert!(env
218+
.get_function(&Identifier::new_unsafe("identity".to_string()))
219+
.is_some());
220+
assert!(env
221+
.get_function(&Identifier::new_unsafe("inverse".to_string()))
222+
.is_some());
223+
}
158224
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ impl<K: MatrixNumber> eframe::App for MatrixApp<K> {
187187
let mut windows_result = None;
188188
for (id, window) in self.state.windows.iter_mut() {
189189
if window.is_open {
190-
let element = self.state.env.get(id).unwrap();
190+
let element = self.state.env.get_value(id).unwrap();
191191
let local_result = display_env_element_window(
192192
ctx,
193193
(id, element),

src/parser.rs

Lines changed: 153 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,10 @@ impl<'a> Tokenizer<'a> {
7474
}
7575
}
7676

77-
#[derive(Debug, Clone, PartialEq, Eq)]
77+
#[derive(Clone, PartialEq, Eq)]
7878
enum WorkingToken<T: MatrixNumber> {
7979
Type(Type<T>),
80+
Function(Identifier),
8081
UnaryOp(char),
8182
BinaryOp(char),
8283
LeftBracket,
@@ -87,6 +88,7 @@ impl<T: MatrixNumber> Display for WorkingToken<T> {
8788
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
8889
match self {
8990
WorkingToken::Type(_) => write!(f, "value token"),
91+
WorkingToken::Function(_) => write!(f, "function token"),
9092
WorkingToken::UnaryOp(op) => write!(f, "unary operator \"{op}\""),
9193
WorkingToken::BinaryOp(op) => write!(f, "binary operator \"{op}\""),
9294
WorkingToken::LeftBracket => write!(f, "( bracket"),
@@ -114,23 +116,35 @@ fn binary_op<T: MatrixNumber>(left: Type<T>, right: Type<T>, op: char) -> anyhow
114116
(Type::Scalar(l), Type::Matrix(r)) => Type::from_matrix_result(r.checked_mul_scl(&l)),
115117
},
116118
'/' => match (left, right) {
117-
(Type::Scalar(l), Type::Scalar(r)) => if !r.is_zero() {
118-
Type::from_scalar_option(l.checked_div(&r))
119-
} else {
120-
bail!("Division by zero!")
121-
},
122-
(Type::Matrix(_), Type::Matrix(_)) => bail!("WTF dividing by matrix? You should use the `inv` function (not implemented yet, wait for it...)"),
123-
(Type::Matrix(_), Type::Scalar(_)) => bail!("Diving matrix by scalar is not supported yet..."),
124-
(Type::Scalar(_), Type::Matrix(_)) => bail!("Diving scalar by matrix does not make sense!"),
119+
(Type::Scalar(l), Type::Scalar(r)) => {
120+
if !r.is_zero() {
121+
Type::from_scalar_option(l.checked_div(&r))
122+
} else {
123+
bail!("Division by zero!")
124+
}
125+
}
126+
(Type::Matrix(_), Type::Matrix(_)) => {
127+
bail!("WTF dividing by matrix? You should use the `inverse` function instead!")
128+
}
129+
(Type::Matrix(_), Type::Scalar(_)) => {
130+
bail!("Diving matrix by scalar is not supported yet...")
131+
}
132+
(Type::Scalar(_), Type::Matrix(_)) => {
133+
bail!("Diving scalar by matrix does not make sense!")
134+
}
125135
},
126-
'^' => if let Type::Scalar(exp) = right {
127-
let exp = exp.to_usize().context("Exponent should be a nonnegative integer.")?;
128-
match left {
129-
Type::Scalar(base) => Type::from_scalar_option(checked_pow(base, exp)),
130-
Type::Matrix(base) => Type::from_matrix_result(base.checked_pow(exp)),
136+
'^' => {
137+
if let Type::Scalar(exp) = right {
138+
let exp = exp
139+
.to_usize()
140+
.context("Exponent should be a nonnegative integer.")?;
141+
match left {
142+
Type::Scalar(base) => Type::from_scalar_option(checked_pow(base, exp)),
143+
Type::Matrix(base) => Type::from_matrix_result(base.checked_pow(exp)),
144+
}
145+
} else {
146+
bail!("Exponent cannot be a matrix!");
131147
}
132-
} else {
133-
bail!("Exponent cannot be a matrix!");
134148
}
135149
_ => unimplemented!(),
136150
}
@@ -155,7 +169,7 @@ fn unary_op<T: MatrixNumber>(arg: Type<T>, op: char) -> anyhow::Result<Type<T>>
155169
<unary_op> ::= "+" | "-"
156170
<binary_op> ::= "+" | "-" | "*" | "/"
157171
<expr> ::= <integer> | <identifier> | <expr> <binary_op> <expr>
158-
| "(" <expr> ")" | <unary_op> <expr>
172+
| "(" <expr> ")" | <unary_op> <expr> | <identifier> "(" <expr> ")"
159173
*/
160174
pub fn parse_expression<T: MatrixNumber>(
161175
raw: &str,
@@ -185,6 +199,7 @@ pub fn parse_expression<T: MatrixNumber>(
185199
None | Some(WorkingToken::LeftBracket)
186200
| Some(WorkingToken::BinaryOp(_))
187201
| Some(WorkingToken::UnaryOp(_))
202+
| Some(WorkingToken::Function(_))
188203
),
189204
Token::Operator(_) => matches!(
190205
previous,
@@ -221,15 +236,18 @@ pub fn parse_expression<T: MatrixNumber>(
221236
outputs.back()
222237
}
223238
Token::Identifier(id) => {
224-
outputs.push_back(WorkingToken::Type(
225-
env.get(id)
226-
.context(format!(
227-
"Undefined identifier! Object \"{}\" is unknown.",
228-
id.to_string()
229-
))?
230-
.clone(),
231-
));
232-
outputs.back()
239+
if let Some(value) = env.get_value(id) {
240+
outputs.push_back(WorkingToken::Type(value.clone()));
241+
outputs.back()
242+
} else if env.get_function(id).is_some() {
243+
operators.push_front(WorkingToken::Function(id.clone()));
244+
operators.front()
245+
} else {
246+
bail!(
247+
"Undefined identifier! Object \"{}\" is unknown.",
248+
id.to_string()
249+
)
250+
}
233251
}
234252
Token::LeftBracket => {
235253
operators.push_front(WorkingToken::LeftBracket);
@@ -248,10 +266,11 @@ pub fn parse_expression<T: MatrixNumber>(
248266
bail!("Mismatched brackets!");
249267
}
250268
if let Some(op) = operators.pop_front() {
251-
if matches!(op, WorkingToken::UnaryOp(_)) {
252-
outputs.push_back(op);
253-
} else {
254-
operators.push_front(op);
269+
match op {
270+
WorkingToken::UnaryOp(_) | WorkingToken::Function(_) => {
271+
outputs.push_back(op)
272+
}
273+
_ => operators.push_front(op),
255274
}
256275
}
257276
Some(&WorkingToken::RightBracket)
@@ -312,6 +331,10 @@ pub fn parse_expression<T: MatrixNumber>(
312331
let arg = val_stack.pop_front().context("Invalid expression!")?;
313332
val_stack.push_front(unary_op(arg, op)?);
314333
}
334+
WorkingToken::Function(id) => {
335+
let arg = val_stack.pop_front().context("Invalid expression!")?;
336+
val_stack.push_front(env.get_function(&id).unwrap()(arg)?);
337+
}
315338
_ => unreachable!(),
316339
}
317340
}
@@ -546,7 +569,8 @@ mod tests {
546569
}
547570

548571
assert_eq!(
549-
*env.get(&Identifier::new("b".to_string()).unwrap()).unwrap(),
572+
*env.get_value(&Identifier::new("b".to_string()).unwrap())
573+
.unwrap(),
550574
Type::<i64>::Scalar(89)
551575
);
552576
}
@@ -561,8 +585,105 @@ mod tests {
561585
exec("a = $ ^ $");
562586

563587
assert_eq!(
564-
*env.get(&Identifier::new("a".to_string()).unwrap()).unwrap(),
588+
*env.get_value(&Identifier::new("a".to_string()).unwrap())
589+
.unwrap(),
565590
Type::<i64>::Scalar(256)
566591
);
567592
}
593+
594+
#[test]
595+
fn test_expression_functions() {
596+
let mut env = Environment::new();
597+
598+
let a = im![1, 2, 3; 4, 5, 6];
599+
let at = im![1, 4; 2, 5; 3, 6];
600+
let b = im![1, 2; 3, 4];
601+
602+
env.insert(Identifier::new("A".to_string()).unwrap(), Type::Matrix(a));
603+
env.insert(
604+
Identifier::new("B".to_string()).unwrap(),
605+
Type::Matrix(b.clone()),
606+
);
607+
608+
assert_eq!(
609+
parse_expression("transpose(A)", &env).unwrap(),
610+
Type::Matrix(at)
611+
);
612+
assert_eq!(
613+
parse_expression("identity(4)", &env).unwrap(),
614+
Type::Matrix(Matrix::identity(4))
615+
);
616+
assert_eq!(
617+
parse_expression("inverse(B)", &env).unwrap(),
618+
Type::Matrix(b.inverse().unwrap().result)
619+
);
620+
}
621+
622+
#[test]
623+
fn test_nested_functions() {
624+
let mut env = Environment::new();
625+
626+
let a = im![1, 2, 3; 4, 5, 6];
627+
let att = im![1, 2, 3; 4, 5, 6];
628+
629+
env.insert(Identifier::new("A".to_string()).unwrap(), Type::Matrix(a));
630+
631+
assert_eq!(
632+
parse_expression("transpose(transpose(A))", &env).unwrap(),
633+
Type::Matrix(att)
634+
)
635+
}
636+
637+
#[test]
638+
fn test_expr_with_function() {
639+
let mut env = Environment::new();
640+
641+
let a = im![1, 2, 3; 4, 5, 6];
642+
let b = im![1, 2; 3, 4];
643+
644+
env.insert(Identifier::new("A".to_string()).unwrap(), Type::Matrix(a));
645+
env.insert(
646+
Identifier::new("B".to_string()).unwrap(),
647+
Type::Matrix(b.clone()),
648+
);
649+
650+
assert_eq!(
651+
parse_expression("transpose(A) * B", &env).unwrap(),
652+
Type::Matrix(im![13, 18; 17, 24; 21, 30])
653+
);
654+
}
655+
656+
#[test]
657+
fn test_expr_in_function() {
658+
let mut env = Environment::new();
659+
660+
let a = im![1, 2, 3; 4, 5, 6];
661+
let i = Matrix::identity(2);
662+
let at = im![1, 4; 2, 5; 3, 6];
663+
664+
env.insert(Identifier::new("A".to_string()).unwrap(), Type::Matrix(a));
665+
env.insert(Identifier::new("I".to_string()).unwrap(), Type::Matrix(i));
666+
667+
assert_eq!(
668+
parse_expression("transpose(I * A)", &env).unwrap(),
669+
Type::Matrix(at)
670+
);
671+
}
672+
673+
#[test]
674+
fn test_complex_nested_function_with_expr() {
675+
let mut env = Environment::new();
676+
677+
let a = im![1, 2, 3; 4, 5, 6];
678+
679+
env.insert(Identifier::new("A".to_string()).unwrap(), Type::Matrix(a));
680+
681+
assert_eq!(
682+
parse_expression(
683+
"transpose(transpose(identity(2137 - 2135 + 1 - 1 + (42 - 420) * 0) * A) + transpose(identity(2) * A))",
684+
&env
685+
).unwrap(),
686+
Type::Matrix(im![2, 4, 6; 8, 10, 12])
687+
);
688+
}
568689
}

0 commit comments

Comments
 (0)