|
| 1 | +# /usr/bin/env python3 |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | + |
| 4 | +from ot.backend import get_backend_list, jax, tf |
| 5 | +import gc |
| 6 | + |
| 7 | + |
| 8 | +def setup_backends(): |
| 9 | + if jax: |
| 10 | + from jax.config import config |
| 11 | + config.update("jax_enable_x64", True) |
| 12 | + |
| 13 | + if tf: |
| 14 | + from tensorflow.python.ops.numpy_ops import np_config |
| 15 | + np_config.enable_numpy_behavior() |
| 16 | + |
| 17 | + |
| 18 | +def exec_bench(setup, tested_function, param_list, n_runs, warmup_runs): |
| 19 | + backend_list = get_backend_list() |
| 20 | + for i, nx in enumerate(backend_list): |
| 21 | + if nx.__name__ == "tf" and i < len(backend_list) - 1: |
| 22 | + # Tensorflow should be the last one to be benchmarked because |
| 23 | + # as far as I'm aware, there is no way to force it to release |
| 24 | + # GPU memory. Hence, if any other backend is benchmarked after |
| 25 | + # Tensorflow and requires the usage of a GPU, it will not have the |
| 26 | + # full memory available and you may have a GPU Out Of Memory error |
| 27 | + # even though your GPU can technically hold your tensors in memory. |
| 28 | + backend_list.pop(i) |
| 29 | + backend_list.append(nx) |
| 30 | + break |
| 31 | + |
| 32 | + inputs = [setup(param) for param in param_list] |
| 33 | + results = dict() |
| 34 | + for nx in backend_list: |
| 35 | + for i in range(len(param_list)): |
| 36 | + print(nx, param_list[i]) |
| 37 | + args = inputs[i] |
| 38 | + results_nx = nx._bench( |
| 39 | + tested_function, |
| 40 | + *args, |
| 41 | + n_runs=n_runs, |
| 42 | + warmup_runs=warmup_runs |
| 43 | + ) |
| 44 | + gc.collect() |
| 45 | + results_nx_with_param_in_key = dict() |
| 46 | + for key in results_nx: |
| 47 | + new_key = (param_list[i], *key) |
| 48 | + results_nx_with_param_in_key[new_key] = results_nx[key] |
| 49 | + results.update(results_nx_with_param_in_key) |
| 50 | + return results |
| 51 | + |
| 52 | + |
| 53 | +def convert_to_html_table(results, param_name, main_title=None, comments=None): |
| 54 | + string = "<table>\n" |
| 55 | + keys = list(results.keys()) |
| 56 | + params, names, devices, bitsizes = zip(*keys) |
| 57 | + |
| 58 | + devices_names = sorted(list(set(zip(devices, names)))) |
| 59 | + params = sorted(list(set(params))) |
| 60 | + bitsizes = sorted(list(set(bitsizes))) |
| 61 | + length = len(devices_names) + 1 |
| 62 | + cpus_cols = list(devices).count("CPU") / len(bitsizes) / len(params) |
| 63 | + gpus_cols = list(devices).count("GPU") / len(bitsizes) / len(params) |
| 64 | + assert cpus_cols + gpus_cols == len(devices_names) |
| 65 | + |
| 66 | + if main_title is not None: |
| 67 | + string += f'<tr><th align="center" colspan="{length}">{str(main_title)}</th></tr>\n' |
| 68 | + |
| 69 | + for i, bitsize in enumerate(bitsizes): |
| 70 | + |
| 71 | + if i != 0: |
| 72 | + string += f'<tr><td colspan="{length}"> </td></tr>\n' |
| 73 | + |
| 74 | + # make bitsize header |
| 75 | + text = f"{bitsize} bits" |
| 76 | + if comments is not None: |
| 77 | + text += " - " |
| 78 | + if isinstance(comments, (tuple, list)) and len(comments) == len(bitsizes): |
| 79 | + text += str(comments[i]) |
| 80 | + else: |
| 81 | + text += str(comments) |
| 82 | + string += f'<tr><th align="center">Bitsize</th>' |
| 83 | + string += f'<th align="center" colspan="{length - 1}">{text}</th></tr>\n' |
| 84 | + |
| 85 | + # make device header |
| 86 | + string += f'<tr><th align="center">Device</th>' |
| 87 | + string += f'<th align="center" colspan="{cpus_cols}">CPU</th>' |
| 88 | + string += f'<th align="center" colspan="{gpus_cols}">GPU</th></tr>\n' |
| 89 | + |
| 90 | + # make param_name / backend header |
| 91 | + string += f'<tr><th align="center">{param_name}</th>' |
| 92 | + for device, name in devices_names: |
| 93 | + string += f'<th align="center">{name}</th>' |
| 94 | + string += "</tr>\n" |
| 95 | + |
| 96 | + # make results rows |
| 97 | + for param in params: |
| 98 | + string += f'<tr><td align="center">{param}</td>' |
| 99 | + for device, name in devices_names: |
| 100 | + key = (param, name, device, bitsize) |
| 101 | + string += f'<td align="center">{results[key]:.4f}</td>' |
| 102 | + string += "</tr>\n" |
| 103 | + |
| 104 | + string += "</table>" |
| 105 | + return string |
0 commit comments