@@ -74,9 +74,10 @@ impl<'a> Tokenizer<'a> {
74
74
}
75
75
}
76
76
77
- #[ derive( Debug , Clone , PartialEq , Eq ) ]
77
+ #[ derive( Clone , PartialEq , Eq ) ]
78
78
enum WorkingToken < T : MatrixNumber > {
79
79
Type ( Type < T > ) ,
80
+ Function ( Identifier ) ,
80
81
UnaryOp ( char ) ,
81
82
BinaryOp ( char ) ,
82
83
LeftBracket ,
@@ -87,6 +88,7 @@ impl<T: MatrixNumber> Display for WorkingToken<T> {
87
88
fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
88
89
match self {
89
90
WorkingToken :: Type ( _) => write ! ( f, "value token" ) ,
91
+ WorkingToken :: Function ( _) => write ! ( f, "function token" ) ,
90
92
WorkingToken :: UnaryOp ( op) => write ! ( f, "unary operator \" {op}\" " ) ,
91
93
WorkingToken :: BinaryOp ( op) => write ! ( f, "binary operator \" {op}\" " ) ,
92
94
WorkingToken :: LeftBracket => write ! ( f, "( bracket" ) ,
@@ -114,23 +116,35 @@ fn binary_op<T: MatrixNumber>(left: Type<T>, right: Type<T>, op: char) -> anyhow
114
116
( Type :: Scalar ( l) , Type :: Matrix ( r) ) => Type :: from_matrix_result ( r. checked_mul_scl ( & l) ) ,
115
117
} ,
116
118
'/' => match ( left, right) {
117
- ( Type :: Scalar ( l) , Type :: Scalar ( r) ) => if !r. is_zero ( ) {
118
- Type :: from_scalar_option ( l. checked_div ( & r) )
119
- } else {
120
- bail ! ( "Division by zero!" )
121
- } ,
122
- ( Type :: Matrix ( _) , Type :: Matrix ( _) ) => bail ! ( "WTF dividing by matrix? You should use the `inv` function (not implemented yet, wait for it...)" ) ,
123
- ( Type :: Matrix ( _) , Type :: Scalar ( _) ) => bail ! ( "Diving matrix by scalar is not supported yet..." ) ,
124
- ( Type :: Scalar ( _) , Type :: Matrix ( _) ) => bail ! ( "Diving scalar by matrix does not make sense!" ) ,
119
+ ( Type :: Scalar ( l) , Type :: Scalar ( r) ) => {
120
+ if !r. is_zero ( ) {
121
+ Type :: from_scalar_option ( l. checked_div ( & r) )
122
+ } else {
123
+ bail ! ( "Division by zero!" )
124
+ }
125
+ }
126
+ ( Type :: Matrix ( _) , Type :: Matrix ( _) ) => {
127
+ bail ! ( "WTF dividing by matrix? You should use the `inverse` function instead!" )
128
+ }
129
+ ( Type :: Matrix ( _) , Type :: Scalar ( _) ) => {
130
+ bail ! ( "Diving matrix by scalar is not supported yet..." )
131
+ }
132
+ ( Type :: Scalar ( _) , Type :: Matrix ( _) ) => {
133
+ bail ! ( "Diving scalar by matrix does not make sense!" )
134
+ }
125
135
} ,
126
- '^' => if let Type :: Scalar ( exp) = right {
127
- let exp = exp. to_usize ( ) . context ( "Exponent should be a nonnegative integer." ) ?;
128
- match left {
129
- Type :: Scalar ( base) => Type :: from_scalar_option ( checked_pow ( base, exp) ) ,
130
- Type :: Matrix ( base) => Type :: from_matrix_result ( base. checked_pow ( exp) ) ,
136
+ '^' => {
137
+ if let Type :: Scalar ( exp) = right {
138
+ let exp = exp
139
+ . to_usize ( )
140
+ . context ( "Exponent should be a nonnegative integer." ) ?;
141
+ match left {
142
+ Type :: Scalar ( base) => Type :: from_scalar_option ( checked_pow ( base, exp) ) ,
143
+ Type :: Matrix ( base) => Type :: from_matrix_result ( base. checked_pow ( exp) ) ,
144
+ }
145
+ } else {
146
+ bail ! ( "Exponent cannot be a matrix!" ) ;
131
147
}
132
- } else {
133
- bail ! ( "Exponent cannot be a matrix!" ) ;
134
148
}
135
149
_ => unimplemented ! ( ) ,
136
150
}
@@ -155,7 +169,7 @@ fn unary_op<T: MatrixNumber>(arg: Type<T>, op: char) -> anyhow::Result<Type<T>>
155
169
<unary_op> ::= "+" | "-"
156
170
<binary_op> ::= "+" | "-" | "*" | "/"
157
171
<expr> ::= <integer> | <identifier> | <expr> <binary_op> <expr>
158
- | "(" <expr> ")" | <unary_op> <expr>
172
+ | "(" <expr> ")" | <unary_op> <expr> | <identifier> "(" <expr> ")"
159
173
*/
160
174
pub fn parse_expression < T : MatrixNumber > (
161
175
raw : & str ,
@@ -185,6 +199,7 @@ pub fn parse_expression<T: MatrixNumber>(
185
199
None | Some ( WorkingToken :: LeftBracket )
186
200
| Some ( WorkingToken :: BinaryOp ( _) )
187
201
| Some ( WorkingToken :: UnaryOp ( _) )
202
+ | Some ( WorkingToken :: Function ( _) )
188
203
) ,
189
204
Token :: Operator ( _) => matches ! (
190
205
previous,
@@ -221,15 +236,18 @@ pub fn parse_expression<T: MatrixNumber>(
221
236
outputs. back ( )
222
237
}
223
238
Token :: Identifier ( id) => {
224
- outputs. push_back ( WorkingToken :: Type (
225
- env. get ( id)
226
- . context ( format ! (
227
- "Undefined identifier! Object \" {}\" is unknown." ,
228
- id. to_string( )
229
- ) ) ?
230
- . clone ( ) ,
231
- ) ) ;
232
- outputs. back ( )
239
+ if let Some ( value) = env. get_value ( id) {
240
+ outputs. push_back ( WorkingToken :: Type ( value. clone ( ) ) ) ;
241
+ outputs. back ( )
242
+ } else if env. get_function ( id) . is_some ( ) {
243
+ operators. push_front ( WorkingToken :: Function ( id. clone ( ) ) ) ;
244
+ operators. front ( )
245
+ } else {
246
+ bail ! (
247
+ "Undefined identifier! Object \" {}\" is unknown." ,
248
+ id. to_string( )
249
+ )
250
+ }
233
251
}
234
252
Token :: LeftBracket => {
235
253
operators. push_front ( WorkingToken :: LeftBracket ) ;
@@ -248,10 +266,11 @@ pub fn parse_expression<T: MatrixNumber>(
248
266
bail ! ( "Mismatched brackets!" ) ;
249
267
}
250
268
if let Some ( op) = operators. pop_front ( ) {
251
- if matches ! ( op, WorkingToken :: UnaryOp ( _) ) {
252
- outputs. push_back ( op) ;
253
- } else {
254
- operators. push_front ( op) ;
269
+ match op {
270
+ WorkingToken :: UnaryOp ( _) | WorkingToken :: Function ( _) => {
271
+ outputs. push_back ( op)
272
+ }
273
+ _ => operators. push_front ( op) ,
255
274
}
256
275
}
257
276
Some ( & WorkingToken :: RightBracket )
@@ -312,6 +331,10 @@ pub fn parse_expression<T: MatrixNumber>(
312
331
let arg = val_stack. pop_front ( ) . context ( "Invalid expression!" ) ?;
313
332
val_stack. push_front ( unary_op ( arg, op) ?) ;
314
333
}
334
+ WorkingToken :: Function ( id) => {
335
+ let arg = val_stack. pop_front ( ) . context ( "Invalid expression!" ) ?;
336
+ val_stack. push_front ( env. get_function ( & id) . unwrap ( ) ( arg) ?) ;
337
+ }
315
338
_ => unreachable ! ( ) ,
316
339
}
317
340
}
@@ -546,7 +569,8 @@ mod tests {
546
569
}
547
570
548
571
assert_eq ! (
549
- * env. get( & Identifier :: new( "b" . to_string( ) ) . unwrap( ) ) . unwrap( ) ,
572
+ * env. get_value( & Identifier :: new( "b" . to_string( ) ) . unwrap( ) )
573
+ . unwrap( ) ,
550
574
Type :: <i64 >:: Scalar ( 89 )
551
575
) ;
552
576
}
@@ -561,8 +585,105 @@ mod tests {
561
585
exec ( "a = $ ^ $" ) ;
562
586
563
587
assert_eq ! (
564
- * env. get( & Identifier :: new( "a" . to_string( ) ) . unwrap( ) ) . unwrap( ) ,
588
+ * env. get_value( & Identifier :: new( "a" . to_string( ) ) . unwrap( ) )
589
+ . unwrap( ) ,
565
590
Type :: <i64 >:: Scalar ( 256 )
566
591
) ;
567
592
}
593
+
594
+ #[ test]
595
+ fn test_expression_functions ( ) {
596
+ let mut env = Environment :: new ( ) ;
597
+
598
+ let a = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
599
+ let at = im ! [ 1 , 4 ; 2 , 5 ; 3 , 6 ] ;
600
+ let b = im ! [ 1 , 2 ; 3 , 4 ] ;
601
+
602
+ env. insert ( Identifier :: new ( "A" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( a) ) ;
603
+ env. insert (
604
+ Identifier :: new ( "B" . to_string ( ) ) . unwrap ( ) ,
605
+ Type :: Matrix ( b. clone ( ) ) ,
606
+ ) ;
607
+
608
+ assert_eq ! (
609
+ parse_expression( "transpose(A)" , & env) . unwrap( ) ,
610
+ Type :: Matrix ( at)
611
+ ) ;
612
+ assert_eq ! (
613
+ parse_expression( "identity(4)" , & env) . unwrap( ) ,
614
+ Type :: Matrix ( Matrix :: identity( 4 ) )
615
+ ) ;
616
+ assert_eq ! (
617
+ parse_expression( "inverse(B)" , & env) . unwrap( ) ,
618
+ Type :: Matrix ( b. inverse( ) . unwrap( ) . result)
619
+ ) ;
620
+ }
621
+
622
+ #[ test]
623
+ fn test_nested_functions ( ) {
624
+ let mut env = Environment :: new ( ) ;
625
+
626
+ let a = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
627
+ let att = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
628
+
629
+ env. insert ( Identifier :: new ( "A" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( a) ) ;
630
+
631
+ assert_eq ! (
632
+ parse_expression( "transpose(transpose(A))" , & env) . unwrap( ) ,
633
+ Type :: Matrix ( att)
634
+ )
635
+ }
636
+
637
+ #[ test]
638
+ fn test_expr_with_function ( ) {
639
+ let mut env = Environment :: new ( ) ;
640
+
641
+ let a = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
642
+ let b = im ! [ 1 , 2 ; 3 , 4 ] ;
643
+
644
+ env. insert ( Identifier :: new ( "A" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( a) ) ;
645
+ env. insert (
646
+ Identifier :: new ( "B" . to_string ( ) ) . unwrap ( ) ,
647
+ Type :: Matrix ( b. clone ( ) ) ,
648
+ ) ;
649
+
650
+ assert_eq ! (
651
+ parse_expression( "transpose(A) * B" , & env) . unwrap( ) ,
652
+ Type :: Matrix ( im![ 13 , 18 ; 17 , 24 ; 21 , 30 ] )
653
+ ) ;
654
+ }
655
+
656
+ #[ test]
657
+ fn test_expr_in_function ( ) {
658
+ let mut env = Environment :: new ( ) ;
659
+
660
+ let a = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
661
+ let i = Matrix :: identity ( 2 ) ;
662
+ let at = im ! [ 1 , 4 ; 2 , 5 ; 3 , 6 ] ;
663
+
664
+ env. insert ( Identifier :: new ( "A" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( a) ) ;
665
+ env. insert ( Identifier :: new ( "I" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( i) ) ;
666
+
667
+ assert_eq ! (
668
+ parse_expression( "transpose(I * A)" , & env) . unwrap( ) ,
669
+ Type :: Matrix ( at)
670
+ ) ;
671
+ }
672
+
673
+ #[ test]
674
+ fn test_complex_nested_function_with_expr ( ) {
675
+ let mut env = Environment :: new ( ) ;
676
+
677
+ let a = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
678
+
679
+ env. insert ( Identifier :: new ( "A" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( a) ) ;
680
+
681
+ assert_eq ! (
682
+ parse_expression(
683
+ "transpose(transpose(identity(2137 - 2135 + 1 - 1 + (42 - 420) * 0) * A) + transpose(identity(2) * A))" ,
684
+ & env
685
+ ) . unwrap( ) ,
686
+ Type :: Matrix ( im![ 2 , 4 , 6 ; 8 , 10 , 12 ] )
687
+ ) ;
688
+ }
568
689
}
0 commit comments