# plot_helper.py
# Forked from opparco/stable-diffusion-webui-composable-lora.
from dataclasses import dataclass
from typing import List
import io
import matplotlib
matplotlib.use('Agg')
import pandas as pd
from pandas.plotting._matplotlib.style import get_standard_colors
from PIL import Image
@dataclass
class YAxis:
    """One y-axis of a multi-axis plot: its label and the columns drawn on it."""

    # Text used as the y-axis label.
    name: str
    # Names of the DataFrame columns plotted against this axis.
    columns: List[str]
@dataclass
class PlotDefinition:
    """Description of a plot: its title, the x column, and the y-axes to draw."""

    # Title shown above the plot.
    title: str
    # Name of the DataFrame column used for the x-axis.
    x_axis: str
    # Y-axes to draw; the first entry becomes the primary axis in plot_graph.
    y_axis: List[YAxis]
def plot_lora_weight(lora_weights, lora_names):
    """Plot LoRA weights as a line chart and return the chart as an image.

    Args:
        lora_weights: Row data handed to ``pd.DataFrame``; presumably one row
            per sampling step (the x-axis is labelled "Steps") -- confirm
            against the caller.
        lora_names: Column names for the DataFrame, one per plotted line.

    Returns:
        The image produced by ``fig2img`` for the rendered chart.
    """
    # Import pyplot explicitly: a bare ``import matplotlib`` does not by itself
    # guarantee that the ``matplotlib.pyplot`` submodule has been loaded.
    import matplotlib.pyplot as plt

    data = pd.DataFrame(lora_weights, columns=lora_names)
    ax = data.plot()
    try:
        ax.set_xlabel("Steps")
        ax.set_ylabel("LoRA weight")
        ax.set_title("LoRA weight in all steps")
        ax.legend(loc=0)
        return fig2img(ax)
    finally:
        # Figures created through pyplot stay registered until explicitly
        # closed (matplotlib warns once more than 20 are open), so release
        # the figure even when labelling or rendering fails.
        plt.close(ax.figure)
def fig2img(fig):
    """Save a figure into an in-memory buffer and open it as a PIL image.

    Args:
        fig: Any object exposing a ``.figure`` attribute with a ``savefig``
            method; ``plot_lora_weight`` passes the axes returned by
            ``DataFrame.plot``.

    Returns:
        The object returned by ``Image.open`` on the in-memory buffer. The
        buffer is intentionally left open because the image reads from it.
    """
    stream = io.BytesIO()
    figure = fig.figure
    figure.savefig(stream)
    # Rewind so the image reader starts at the first byte written.
    stream.seek(0)
    return Image.open(stream)
def plot_graph(
    data: pd.DataFrame,
    plot_definition: PlotDefinition,
    spacing: float = 0.1,
):
    """Plot ``data`` with one y-axis per entry of ``plot_definition.y_axis``.

    The first y-axis entry is drawn on the primary axes with all of its
    columns in a single colour. Each further entry gets its own twinned
    y-axis whose right spine is shifted outward by ``spacing`` for every
    additional axis.

    Args:
        data: Frame containing the x column and every referenced y column.
        plot_definition: Title, x column name and y-axis definitions.
        spacing: Offset between successive right-hand spines, in axes
            coordinates.

    Returns:
        The primary axes (the one holding the first y-axis entry).

    Raises:
        ValueError: If ``plot_definition.y_axis`` is empty.
    """
    y_axes = plot_definition.y_axis
    if not y_axes:
        # Without this guard the function failed later with an
        # UnboundLocalError on ``ax``.
        raise ValueError("plot_definition.y_axis must contain at least one YAxis")

    primary, *secondary = y_axes

    # One colour for the primary axis plus one per column of the widest
    # secondary axis. The previous fixed budget (len(y_axis) + 7) is kept as
    # a lower bound; it raised IndexError when a secondary axis had more
    # columns than that budget allowed.
    # NOTE(review): get_standard_colors appears to cycle the default palette,
    # so asking for more colours should not change the earlier ones --
    # confirm against the installed pandas version.
    widest = max((len(yi.columns) for yi in secondary), default=0)
    colors = get_standard_colors(num_colors=max(len(y_axes) + 7, widest + 1))
    loss_color = colors[0]
    avg_colors = colors[1:]

    ax = data.plot(
        x=plot_definition.x_axis,
        y=primary.columns,
        title=plot_definition.title,
        color=[loss_color] * len(primary.columns),
    )
    ax.set_ylabel(ylabel=primary.name)

    for offset, yi in enumerate(secondary):
        # Multiple y-axes: the first extra axis keeps the default right spine
        # position (offset 0); later ones move outward by ``spacing`` each.
        ax_new = ax.twinx()
        ax_new.spines["right"].set_position(("axes", 1 + spacing * offset))
        data.plot(
            ax=ax_new,
            x=plot_definition.x_axis,
            y=yi.columns,
            color=avg_colors[:len(yi.columns)],
        )
        ax_new.set_ylabel(ylabel=yi.name)

    ax.legend(loc=0)
    return ax