11//! The actual token generator for the macro 
2- use  quote:: quote; 
3- use  syn:: { punctuated:: Punctuated ,  token:: Comma ,  Ident ,  ItemEnum ,  LitStr ,  Variant } ; 
2+ 
3+ use  { 
4+     crate :: parser:: SplProgramErrorArgs , 
5+     proc_macro2:: Span , 
6+     quote:: quote, 
7+     syn:: { 
8+         punctuated:: Punctuated ,  token:: Comma ,  Expr ,  ExprLit ,  Ident ,  ItemEnum ,  Lit ,  LitInt ,  LitStr , 
9+         Token ,  Variant , 
10+     } , 
11+ } ; 
12+ 
13+ const  SPL_ERROR_HASH_NAMESPACE :  & str  = "spl_program_error" ; 
14+ const  SPL_ERROR_HASH_MIN_VALUE :  u32  = 7_000 ; 
415
516/// The type of macro being called, thus directing which tokens to generate 
617#[ allow( clippy:: enum_variant_names) ]  
718pub  enum  MacroType  { 
8-     IntoProgramError , 
9-     DecodeError , 
10-     PrintProgramError , 
11-     SplProgramError , 
19+     IntoProgramError  { 
20+         ident :  Ident , 
21+     } , 
22+     DecodeError  { 
23+         ident :  Ident , 
24+     } , 
25+     PrintProgramError  { 
26+         ident :  Ident , 
27+         variants :  Punctuated < Variant ,  Comma > , 
28+     } , 
29+     SplProgramError  { 
30+         args :  SplProgramErrorArgs , 
31+         item_enum :  ItemEnum , 
32+     } , 
1233} 
1334
1435impl  MacroType  { 
1536    /// Generates the corresponding tokens based on variant selection 
16-      pub  fn  generate_tokens ( & self ,   item_enum :   ItemEnum )  -> proc_macro2:: TokenStream  { 
37+      pub  fn  generate_tokens ( & mut   self )  -> proc_macro2:: TokenStream  { 
1738        match  self  { 
18-             MacroType :: IntoProgramError  => into_program_error ( & item_enum. ident ) , 
19-             MacroType :: DecodeError  => decode_error ( & item_enum. ident ) , 
20-             MacroType :: PrintProgramError  => { 
21-                 print_program_error ( & item_enum. ident ,  & item_enum. variants ) 
22-             } 
23-             MacroType :: SplProgramError  => spl_program_error ( item_enum) , 
39+             Self :: IntoProgramError  {  ident }  => into_program_error ( ident) , 
40+             Self :: DecodeError  {  ident }  => decode_error ( ident) , 
41+             Self :: PrintProgramError  {  ident,  variants }  => print_program_error ( ident,  variants) , 
42+             Self :: SplProgramError  {  args,  item_enum }  => spl_program_error ( args,  item_enum) , 
2443        } 
2544    } 
2645} 
2746
28- /// Builds the implementation of `Into<solana_program::program_error::ProgramError>` 
29- /// More specifically, implements `From<Self> for solana_program::program_error::ProgramError` 
47+ /// Builds the implementation of 
48+ /// `Into<solana_program::program_error::ProgramError>` More specifically, 
49+ /// implements `From<Self> for solana_program::program_error::ProgramError` 
3050pub  fn  into_program_error ( ident :  & Ident )  -> proc_macro2:: TokenStream  { 
3151    quote !  { 
3252        impl  From <#ident> for  solana_program:: program_error:: ProgramError  { 
@@ -48,7 +68,8 @@ pub fn decode_error(ident: &Ident) -> proc_macro2::TokenStream {
4868    } 
4969} 
5070
51- /// Builds the implementation of `solana_program::program_error::PrintProgramError` 
71+ /// Builds the implementation of 
72+ /// `solana_program::program_error::PrintProgramError` 
5273pub  fn  print_program_error ( 
5374    ident :  & Ident , 
5475    variants :  & Punctuated < Variant ,  Comma > , 
@@ -96,16 +117,25 @@ fn get_error_message(variant: &Variant) -> Option<String> {
96117
97118/// The main function that produces the tokens required to turn your 
98119/// error enum into a Solana Program Error 
99- pub  fn  spl_program_error ( input :  ItemEnum )  -> proc_macro2:: TokenStream  { 
100-     let  ident = & input. ident ; 
101-     let  variants = & input. variants ; 
120+ pub  fn  spl_program_error ( 
121+     args :  & SplProgramErrorArgs , 
122+     item_enum :  & mut  ItemEnum , 
123+ )  -> proc_macro2:: TokenStream  { 
124+     if  let  Some ( error_code_start)  = args. hash_error_code_start  { 
125+         set_first_discriminant ( item_enum,  error_code_start) ; 
126+     } 
127+ 
128+     let  ident = & item_enum. ident ; 
129+     let  variants = & item_enum. variants ; 
102130    let  into_program_error = into_program_error ( ident) ; 
103131    let  decode_error = decode_error ( ident) ; 
104132    let  print_program_error = print_program_error ( ident,  variants) ; 
133+ 
105134    quote !  { 
135+         #[ repr( u32 ) ] 
106136        #[ derive( Clone ,  Debug ,  Eq ,  thiserror:: Error ,  num_derive:: FromPrimitive ,  PartialEq ) ] 
107137        #[ num_traits = "num_traits" ] 
108-         #input 
138+         #item_enum 
109139
110140        #into_program_error
111141
@@ -114,3 +144,55 @@ pub fn spl_program_error(input: ItemEnum) -> proc_macro2::TokenStream {
114144        #print_program_error
115145    } 
116146} 
147+ 
148+ /// This function adds a discriminant to the first enum variant based on the 
149+ /// hash of the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant 
150+ /// name. 
151+ /// It will then check to make sure the provided `hash_error_code_start` is 
152+ /// equal to the hash-produced `u32`. 
153+ /// 
154+ /// See https://docs.rs/syn/latest/syn/struct.Variant.html 
155+ fn  set_first_discriminant ( item_enum :  & mut  ItemEnum ,  error_code_start :  u32 )  { 
156+     let  enum_ident = & item_enum. ident ; 
157+     if  item_enum. variants . is_empty ( )  { 
158+         panic ! ( "Enum must have at least one variant" ) ; 
159+     } 
160+     let  first_variant = & mut  item_enum. variants [ 0 ] ; 
161+     let  discriminant = u32_from_hash ( enum_ident) ; 
162+     if  discriminant == error_code_start { 
163+         let  eq = Token ! [ =] ( Span :: call_site ( ) ) ; 
164+         let  expr = Expr :: Lit ( ExprLit  { 
165+             attrs :  Vec :: new ( ) , 
166+             lit :  Lit :: Int ( LitInt :: new ( & discriminant. to_string ( ) ,  Span :: call_site ( ) ) ) , 
167+         } ) ; 
168+         first_variant. discriminant  = Some ( ( eq,  expr) ) ; 
169+     }  else  { 
170+         panic ! ( 
171+             "Error code start value from hash must be {0}. Update your macro attribute to \  
172+               `#[spl_program_error(hash_error_code_start = {0})]`.", 
173+             discriminant
174+         ) ; 
175+     } 
176+ } 
177+ 
178+ /// Hashes the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant 
179+ /// name and returns four middle bytes (13 through 16) as a u32. 
180+ fn  u32_from_hash ( enum_ident :  & Ident )  -> u32  { 
181+     let  hash_input = format ! ( "{}:{}" ,  SPL_ERROR_HASH_NAMESPACE ,  enum_ident) ; 
182+ 
183+     // We don't want our error code to start at any number below 
184+     // `SPL_ERROR_HASH_MIN_VALUE`! 
185+     let  mut  nonce:  u32  = 0 ; 
186+     loop  { 
187+         let  hash = solana_program:: hash:: hashv ( & [ hash_input. as_bytes ( ) ,  & nonce. to_le_bytes ( ) ] ) ; 
188+         let  d = u32:: from_le_bytes ( 
189+             hash. to_bytes ( ) [ 13 ..17 ] 
190+                 . try_into ( ) 
191+                 . expect ( "Unable to convert hash to u32" ) , 
192+         ) ; 
193+         if  d >= SPL_ERROR_HASH_MIN_VALUE  { 
194+             return  d; 
195+         } 
196+         nonce += 1 ; 
197+     } 
198+ } 
0 commit comments