Skip to content

Commit 775dc16

Browse files
committed
Final touches
1 parent ff2a057 commit 775dc16

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ modeling:
2222
stratify_by: "neighbourhood_group"
2323
# Maximum number of features to consider for the TFIDF applied to the title of the
2424
# listing (the column called "name")
25-
max_tfidf_features: 30
25+
max_tfidf_features: 5
2626
# NOTE: you can put here any parameter that is accepted by the constructor of
2727
# RandomForestRegressor. This is a subsample, but more could be added:
2828
random_forest:

src/train_random_forest/run.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def go(args):
4848
rf_config = json.load(fp)
4949
run.config.update(rf_config)
5050

51+
# Fix the random seed for the Random Forest, so we get reproducible results
52+
rf_config['random_state'] = args.random_seed
53+
5154
######################################
5255
# Use run.use_artifact(...).file() to get the train and validation artifact (args.trainval_artifact)
5356
# and save the returned path in train_local_path
@@ -60,7 +63,7 @@ def go(args):
6063
logger.info(f"Minimum price: {y.min()}, Maximum price: {y.max()}")
6164

6265
X_train, X_val, y_train, y_val = train_test_split(
63-
X, y, test_size=args.val_size, stratify=X[args.stratify_by]
66+
X, y, test_size=args.val_size, stratify=X[args.stratify_by], random_state=args.random_seed
6467
)
6568

6669
logger.info("Preparing sklearn pipeline")

0 commit comments

Comments
 (0)