@@ -58,12 +58,14 @@ fn train_decision_tree(data: &Array2<f32>, sq: SplitQuality) -> DecisionTree<f32
58
58
}
59
59
60
60
fn save_model ( model : & DecisionTree < f32 , String > , file_name : & str ) {
61
- let model_file = File :: create ( file_name) . unwrap ( ) ;
61
+ let base_path = env:: var ( "CONFIG_PATH" ) . unwrap_or ( String :: new ( ) ) ;
62
+ let model_file = File :: create ( format ! ( "{}{}" , base_path, file_name) ) . unwrap ( ) ;
62
63
bincode:: serialize_into ( & model_file, & model) . unwrap ( ) ;
63
64
}
64
65
65
66
fn load_model ( file_name : & str ) -> Option < DecisionTree < f32 , String > > {
66
- if let Ok ( mut model_file) = File :: open ( file_name) {
67
+ let base_path = env:: var ( "CONFIG_PATH" ) . unwrap_or ( String :: new ( ) ) ;
68
+ if let Ok ( mut model_file) = File :: open ( format ! ( "{}{}" , base_path, file_name) ) {
67
69
let mut buffer = Vec :: new ( ) ;
68
70
model_file. read_to_end ( & mut buffer) . unwrap ( ) ;
69
71
let deserialized_model: DecisionTree < f32 , String > = bincode:: deserialize ( & buffer) . unwrap ( ) ;
@@ -86,7 +88,8 @@ fn load_model(file_name: &str) -> Option<DecisionTree<f32, String>> {
86
88
pub fn recreate_model ( game_count : usize ) {
87
89
let token: & str = & env:: var ( "RIOT_TOKEN" )
88
90
. expect ( "Could not fetch the Riot token" ) ;
89
- let snowflake_map = create_snowflake_puuid_map ( crate :: constants:: MAPPING_FILE ) ;
91
+ let base_path = env:: var ( "CONFIG_PATH" ) . unwrap_or ( String :: new ( ) ) ;
92
+ let snowflake_map = create_snowflake_puuid_map ( & format ! ( "{}{}" , base_path, crate :: constants:: MAPPING_FILE ) ) ;
90
93
let test_puuid = snowflake_map. values ( ) . filter ( |id| id. starts_with ( "f7Xz" ) ) . nth ( 0 ) . unwrap ( ) . clone ( ) ;
91
94
92
95
// Train a new model
@@ -107,7 +110,8 @@ pub fn recreate_model(game_count: usize) {
107
110
pub fn predict ( snowflake : & str ) -> Option < String > {
108
111
let token: & str = & env:: var ( "RIOT_TOKEN" )
109
112
. expect ( "Could not fetch the Riot token" ) ;
110
- let snowflake_map = create_snowflake_puuid_map ( crate :: constants:: MAPPING_FILE ) ;
113
+ let base_path = env:: var ( "CONFIG_PATH" ) . unwrap_or ( String :: new ( ) ) ;
114
+ let snowflake_map = create_snowflake_puuid_map ( & format ! ( "{}{}" , base_path, crate :: constants:: MAPPING_FILE ) ) ;
111
115
112
116
if let Some ( puuid) = snowflake_map. get ( snowflake) {
113
117
if let Some ( model) = load_model ( MODEL_FILE_NAME ) {
0 commit comments