Skip to content

Commit 0efdc84

Browse files
authored
Merge pull request #1 from dastpis/feature/multiple-indicator-labels
feat: add multiple vector names handling
2 parents 0ce24d8 + 57d9f51 commit 0efdc84

File tree

3 files changed

+71
-12
lines changed

3 files changed

+71
-12
lines changed

backtesting/_plotting.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -537,10 +537,23 @@ def __eq__(self, other):
537537
colors = value._opts['color']
538538
colors = colors and cycle(_as_list(colors)) or (
539539
cycle([next(ohlc_colors)]) if is_overlay else colorgen())
540-
legend_label = LegendStr(value.name)
541-
for j, arr in enumerate(value, 1):
540+
541+
if isinstance(value.name, str):
542+
tooltip_label = value.name
543+
if len(value) == 1:
544+
legend_labels = [LegendStr(value.name)]
545+
else:
546+
legend_labels = [
547+
LegendStr(f"{value.name}[{i}]")
548+
for i in range(len(value))
549+
]
550+
else:
551+
tooltip_label = ", ".join(value.name)
552+
legend_labels = [LegendStr(item) for item in value.name]
553+
554+
for j, arr in enumerate(value):
542555
color = next(colors)
543-
source_name = f'{legend_label}_{i}_{j}'
556+
source_name = f'{legend_labels[j]}_{i}_{j}'
544557
if arr.dtype == bool:
545558
arr = arr.astype(int)
546559
source.add(arr, source_name)
@@ -550,24 +563,24 @@ def __eq__(self, other):
550563
if is_scatter:
551564
fig.scatter(
552565
'index', source_name, source=source,
553-
legend_label=legend_label, color=color,
566+
legend_label=legend_labels[j], color=color,
554567
line_color='black', fill_alpha=.8,
555568
marker='circle', radius=BAR_WIDTH / 2 * 1.5)
556569
else:
557570
fig.line(
558571
'index', source_name, source=source,
559-
legend_label=legend_label, line_color=color,
572+
legend_label=legend_labels[j], line_color=color,
560573
line_width=1.3)
561574
else:
562575
if is_scatter:
563576
r = fig.scatter(
564577
'index', source_name, source=source,
565-
legend_label=LegendStr(legend_label), color=color,
578+
legend_label=legend_labels[j], color=color,
566579
marker='circle', radius=BAR_WIDTH / 2 * .9)
567580
else:
568581
r = fig.line(
569582
'index', source_name, source=source,
570-
legend_label=LegendStr(legend_label), line_color=color,
583+
legend_label=legend_labels[j], line_color=color,
571584
line_width=1.3)
572585
# Add dashed centerline just because
573586
mean = float(pd.Series(arr).mean())
@@ -578,9 +591,9 @@ def __eq__(self, other):
578591
line_color='#666666', line_dash='dashed',
579592
line_width=.5))
580593
if is_overlay:
581-
ohlc_tooltips.append((legend_label, NBSP.join(tooltips)))
594+
ohlc_tooltips.append((tooltip_label, NBSP.join(tooltips)))
582595
else:
583-
set_tooltips(fig, [(legend_label, NBSP.join(tooltips))], vline=True, renderers=[r])
596+
set_tooltips(fig, [(tooltip_label, NBSP.join(tooltips))], vline=True, renderers=[r])
584597
# If the sole indicator line on this figure,
585598
# have the legend only contain text without the glyph
586599
if len(value) == 1:

backtesting/backtesting.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def I(self, # noqa: E743
9090
same length as `backtesting.backtesting.Strategy.data`.
9191
9292
In the plot legend, the indicator is labeled with
93-
function name, unless `name` overrides it.
93+
function name, unless `name` overrides it. If `func` returns
94+
multiple arrays, `name` can be a sequence of strings, and
95+
its size must agree with the number of arrays returned.
9496
9597
If `plot` is `True`, the indicator is plotted on the resulting
9698
`backtesting.backtesting.Backtest.plot`.
@@ -115,13 +117,21 @@ def I(self, # noqa: E743
115117
def init():
116118
self.sma = self.I(ta.SMA, self.data.Close, self.n_sma)
117119
"""
120+
def _format_name(name: str) -> str:
121+
return name.format(*map(_as_str, args),
122+
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))
123+
118124
if name is None:
119125
params = ','.join(filter(None, map(_as_str, chain(args, kwargs.values()))))
120126
func_name = _as_str(func)
121127
name = (f'{func_name}({params})' if params else f'{func_name}')
128+
elif isinstance(name, str):
129+
name = _format_name(name)
130+
elif try_(lambda: all(isinstance(item, str) for item in name), False):
131+
name = [_format_name(item) for item in name]
122132
else:
123-
name = name.format(*map(_as_str, args),
124-
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))
133+
raise TypeError(f'Unexpected `name=` type {type(name)}; expected `str` or '
134+
'`Sequence[str]`')
125135

126136
try:
127137
value = func(*args, **kwargs)
@@ -139,6 +149,11 @@ def init():
139149
if is_arraylike and np.argmax(value.shape) == 0:
140150
value = value.T
141151

152+
if isinstance(name, list) and (np.atleast_2d(value).shape[0] != len(name)):
153+
raise ValueError(
154+
f'Length of `name=` ({len(name)}) must agree with the number '
155+
f'of arrays the indicator returns ({value.shape[0]}).')
156+
142157
if not is_arraylike or not 1 <= value.ndim <= 2 or value.shape[-1] != len(self._data.Close):
143158
raise ValueError(
144159
'Indicators must return (optionally a tuple of) numpy.arrays of same '

backtesting/test/_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,37 @@ def test_resample(self):
755755
# Give browser time to open before tempfile is removed
756756
time.sleep(1)
757757

758+
def test_indicator_name(self):
759+
test_self = self
760+
761+
class S(Strategy):
762+
def init(self):
763+
def _SMA():
764+
return SMA(self.data.Close, 5), SMA(self.data.Close, 10)
765+
766+
test_self.assertRaises(TypeError, self.I, _SMA, name=42)
767+
test_self.assertRaises(ValueError, self.I, _SMA, name=("SMA One",))
768+
test_self.assertRaises(
769+
ValueError, self.I, _SMA, name=("SMA One", "SMA Two", "SMA Three"))
770+
771+
for overlay in (True, False):
772+
self.I(SMA, self.data.Close, 5, overlay=overlay)
773+
self.I(SMA, self.data.Close, 5, name="My SMA", overlay=overlay)
774+
self.I(SMA, self.data.Close, 5, name=("My SMA",), overlay=overlay)
775+
self.I(_SMA, overlay=overlay)
776+
self.I(_SMA, name="My SMA", overlay=overlay)
777+
self.I(_SMA, name=("SMA One", "SMA Two"), overlay=overlay)
778+
779+
def next(self):
780+
pass
781+
782+
bt = Backtest(GOOG, S)
783+
bt.run()
784+
with _tempfile() as f:
785+
bt.plot(filename=f,
786+
plot_drawdown=False, plot_equity=False, plot_pl=False, plot_volume=False,
787+
open_browser=False)
788+
758789
def test_indicator_color(self):
759790
class S(Strategy):
760791
def init(self):

0 commit comments

Comments
 (0)