|
12 | 12 | from sklearn.compose import ColumnTransformer
|
13 | 13 |
|
14 | 14 |
|
15 |
| -# ------------------------------------------------------------------------------ |
16 |
| - |
17 |
| -# DATA |
| 15 | +# --------------------------------------------------------------------------------------------- DATA |
18 | 16 |
|
19 | 17 | # Raw values
|
20 | 18 | # - Index = 0 is also expected output
|
|
29 | 27 | ])
|
30 | 28 |
|
31 | 29 |
|
32 |
| -# ------------------------------------------------------------------------------ |
33 |
| - |
34 |
| -# TRANSFORM |
| 30 | +# ---------------------------------------------------------------------------------------- TRANSFORM |
35 | 31 |
|
36 | 32 | # - Raw data index 1-3 (ignore index 0)
|
37 | 33 | # - Raw data index 0
|
38 | 34 | inputs = data[:, 1:]
|
39 | 35 | outputs = data[:, 0].astype(int)
|
40 |
| -#print(inputs) |
41 |
| -#print(outputs) |
42 | 36 |
|
43 | 37 |
|
44 | 38 | # Strings become integers (here: Asteroid=0, Particle=1)
|
|
52 | 46 | transformers=[("encoder", OneHotEncoder(), [0])],
|
53 | 47 | remainder="passthrough"
|
54 | 48 | )
|
55 |
| -inputs = np.array(ct.fit_transform(inputs), dtype=float) |
56 |
| -#print(inputs) |
57 | 49 |
|
| 50 | +inputs = np.array(ct.fit_transform(inputs), dtype=float) |
58 | 51 |
|
59 |
| -# ------------------------------------------------------------------------------ |
60 | 52 |
|
61 |
| -# MODEL |
| 53 | +# -------------------------------------------------------------------------------------------- MODEL |
62 | 54 |
|
63 | 55 | # Hidden layers = Between in/out, not visible/usable?
|
64 | 56 | # - 8 neurons with 4 inputs
|
|
78 | 70 | model.fit(inputs, outputs, epochs=1000, batch_size=2)
|
79 | 71 |
|
80 | 72 |
|
81 |
| -# ------------------------------------------------------------------------------ |
82 |
| -# PREDICTIONS |
| 73 | +# -------------------------------------------------------------------------------------- PREDICTIONS |
83 | 74 |
|
84 | 75 | def predict_collision(category: str, pos_x_item: int, pos_x_player: int):
|
85 | 76 | # Convert the input data into a NumPy array
|
86 | 77 | input_data = np.array([[category, pos_x_item, pos_x_player]])
|
87 |
| - #print(input_data) |
88 | 78 |
|
89 | 79 | # Transform the input data using the ColumnTransformer
|
90 | 80 | input_data_transformed = ct.transform(input_data).astype(float)
|
91 |
| - #print(input_data_transformed) |
92 | 81 |
|
93 | 82 | # Use the model to predict the probability of a collision
|
94 | 83 | prediction = model.predict(input_data_transformed)
|
95 |
| - #print(prediction) |
96 | 84 |
|
97 | 85 | return prediction[0][0] > 0.5
|
98 | 86 |
|
99 | 87 |
|
100 |
| -# ------------------------------------------------------------------------------ |
101 |
| -# TESTS |
| 88 | +# -------------------------------------------------------------------------------------------- TESTS |
102 | 89 |
|
103 | 90 | # Test model with examples
|
104 |
| -print(predict_collision("ASTEROID", 100, 100)) # Expected output: True |
105 |
| -print(predict_collision("ASTEROID", 0, 100)) # Expected output: False |
106 |
| -print(predict_collision("ASTEROID", 100, 0)) # Expected output: False |
107 |
| -print(predict_collision("PARTICLE", 100, 100)) # Expected output: False |
108 |
| -print(predict_collision("PARTICLE", 0, 100)) # Expected output: False |
109 |
| -print(predict_collision("PARTICLE", 100, 0)) # Expected output: False |
110 |
| -print(predict_collision("ASTEROID", 0, 0)) # Expected output: True |
| 91 | +print(predict_collision("ASTEROID", 100, 100)) # Expected output: True |
| 92 | +print(predict_collision("ASTEROID", 0, 100)) # Expected output: False |
| 93 | +print(predict_collision("ASTEROID", 100, 0)) # Expected output: False |
| 94 | +print(predict_collision("PARTICLE", 100, 100)) # Expected output: False |
| 95 | +print(predict_collision("PARTICLE", 0, 100)) # Expected output: False |
| 96 | +print(predict_collision("PARTICLE", 100, 0)) # Expected output: False |
| 97 | +print(predict_collision("ASTEROID", 0, 0)) # Expected output: True |
0 commit comments