Skip to content

Commit 281fc0a

Browse files
committed
BUG: Allow multiple names for vector indicators (#382)
Previously we only allowed one name per vector indicator: def _my_indicator(open, close): return tuple( _my_indicator_one(open, close), _my_indicator_two(open, close), ) self.I( _my_indicator, # One name is used to describe two values name="My Indicator", self.data.Open, self.data.Close ) Now, the user can supply two (or more) names to annotate each value individually. The names will be shown in the plot legend. The following is now valid: self.I( _my_indicator, # One name is used to describe two values name=["My Indicator One", "My Indicator Two"], self.data.Open, self.data.Close )
1 parent 0ce24d8 commit 281fc0a

File tree

3 files changed

+72
-13
lines changed

3 files changed

+72
-13
lines changed

backtesting/_plotting.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
import warnings
55
from colorsys import hls_to_rgb, rgb_to_hls
6-
from itertools import cycle, combinations
6+
from itertools import cycle, combinations, repeat
77
from functools import partial
88
from typing import Callable, List, Union
99

@@ -537,10 +537,24 @@ def __eq__(self, other):
537537
colors = value._opts['color']
538538
colors = colors and cycle(_as_list(colors)) or (
539539
cycle([next(ohlc_colors)]) if is_overlay else colorgen())
540-
legend_label = LegendStr(value.name)
541-
for j, arr in enumerate(value, 1):
540+
541+
tooltip_label = value.name
542+
543+
if len(value) == 1:
544+
assert isinstance(value.name, str)
545+
legend_labels = [LegendStr(value.name)]
546+
elif isinstance(value.name, str):
547+
legend_labels = [
548+
LegendStr(f"{name}[{index}]")
549+
for index, name in enumerate(repeat(value.name, len(value)))
550+
]
551+
else:
552+
legend_labels = [LegendStr(item) for item in value.name]
553+
tooltip_label = ", ".join(value.name)
554+
555+
for j, arr in enumerate(value):
542556
color = next(colors)
543-
source_name = f'{legend_label}_{i}_{j}'
557+
source_name = f'{legend_labels[j]}_{i}_{j}'
544558
if arr.dtype == bool:
545559
arr = arr.astype(int)
546560
source.add(arr, source_name)
@@ -550,24 +564,24 @@ def __eq__(self, other):
550564
if is_scatter:
551565
fig.scatter(
552566
'index', source_name, source=source,
553-
legend_label=legend_label, color=color,
567+
legend_label=legend_labels[j], color=color,
554568
line_color='black', fill_alpha=.8,
555569
marker='circle', radius=BAR_WIDTH / 2 * 1.5)
556570
else:
557571
fig.line(
558572
'index', source_name, source=source,
559-
legend_label=legend_label, line_color=color,
573+
legend_label=legend_labels[j], line_color=color,
560574
line_width=1.3)
561575
else:
562576
if is_scatter:
563577
r = fig.scatter(
564578
'index', source_name, source=source,
565-
legend_label=LegendStr(legend_label), color=color,
579+
legend_label=legend_labels[j], color=color,
566580
marker='circle', radius=BAR_WIDTH / 2 * .9)
567581
else:
568582
r = fig.line(
569583
'index', source_name, source=source,
570-
legend_label=LegendStr(legend_label), line_color=color,
584+
legend_label=legend_labels[j], line_color=color,
571585
line_width=1.3)
572586
# Add dashed centerline just because
573587
mean = float(pd.Series(arr).mean())
@@ -578,9 +592,9 @@ def __eq__(self, other):
578592
line_color='#666666', line_dash='dashed',
579593
line_width=.5))
580594
if is_overlay:
581-
ohlc_tooltips.append((legend_label, NBSP.join(tooltips)))
595+
ohlc_tooltips.append((tooltip_label, NBSP.join(tooltips)))
582596
else:
583-
set_tooltips(fig, [(legend_label, NBSP.join(tooltips))], vline=True, renderers=[r])
597+
set_tooltips(fig, [(tooltip_label, NBSP.join(tooltips))], vline=True, renderers=[r])
584598
# If the sole indicator line on this figure,
585599
# have the legend only contain text without the glyph
586600
if len(value) == 1:

backtesting/backtesting.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def I(self, # noqa: E743
9090
same length as `backtesting.backtesting.Strategy.data`.
9191
9292
In the plot legend, the indicator is labeled with
93-
function name, unless `name` overrides it.
93+
function name, unless `name` overrides it. If `func` returns
94+
multiple arrays, `name` can be a collection of strings, and
95+
the size must agree with the number of arrays returned.
9496
9597
If `plot` is `True`, the indicator is plotted on the resulting
9698
`backtesting.backtesting.Backtest.plot`.
@@ -115,13 +117,21 @@ def I(self, # noqa: E743
115117
def init():
116118
self.sma = self.I(ta.SMA, self.data.Close, self.n_sma)
117119
"""
120+
def _format_name(name: str) -> str:
121+
return name.format(*map(_as_str, args),
122+
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))
123+
118124
if name is None:
119125
params = ','.join(filter(None, map(_as_str, chain(args, kwargs.values()))))
120126
func_name = _as_str(func)
121127
name = (f'{func_name}({params})' if params else f'{func_name}')
128+
elif isinstance(name, str):
129+
name = _format_name(name)
130+
elif try_(lambda: all(isinstance(item, str) for item in name), False):
131+
name = [_format_name(item) for item in name]
122132
else:
123-
name = name.format(*map(_as_str, args),
124-
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))
133+
raise TypeError(f'Unexpected `name` type {type(name)}, `str` or `Iterable[str]` '
134+
'was expected.')
125135

126136
try:
127137
value = func(*args, **kwargs)
@@ -139,6 +149,11 @@ def init():
139149
if is_arraylike and np.argmax(value.shape) == 0:
140150
value = value.T
141151

152+
if isinstance(name, list) and (value.ndim != 2 or value.shape[0] != len(name)):
153+
raise ValueError(
154+
f'The number of `name` elements ({len(name)}) must agree with the nubmer '
155+
f'of arrays ({value.shape[0]}) the indicator returns.')
156+
142157
if not is_arraylike or not 1 <= value.ndim <= 2 or value.shape[-1] != len(self._data.Close):
143158
raise ValueError(
144159
'Indicators must return (optionally a tuple of) numpy.arrays of same '

backtesting/test/_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,36 @@ def test_resample(self):
755755
# Give browser time to open before tempfile is removed
756756
time.sleep(1)
757757

758+
def test_indicator_name(self):
759+
test_self = self
760+
761+
class S(Strategy):
762+
def init(self):
763+
def _SMA():
764+
return SMA(self.data.Close, 5), SMA(self.data.Close, 10)
765+
766+
test_self.assertRaises(TypeError, self.I, _SMA, name=42)
767+
test_self.assertRaises(ValueError, self.I, _SMA, name=("SMA One", ))
768+
test_self.assertRaises(
769+
ValueError, self.I, _SMA, name=("SMA One", "SMA Two", "SMA Three"))
770+
771+
for overlay in (True, False):
772+
self.I(SMA, self.data.Close, 5, overlay=overlay)
773+
self.I(SMA, self.data.Close, 5, name="My SMA", overlay=overlay)
774+
self.I(_SMA, overlay=overlay)
775+
self.I(_SMA, name="My SMA", overlay=overlay)
776+
self.I(_SMA, name=("SMA One", "SMA Two"), overlay=overlay)
777+
778+
def next(self):
779+
pass
780+
781+
bt = Backtest(GOOG, S)
782+
bt.run()
783+
with _tempfile() as f:
784+
bt.plot(filename=f,
785+
plot_drawdown=False, plot_equity=False, plot_pl=False, plot_volume=False,
786+
open_browser=False)
787+
758788
def test_indicator_color(self):
759789
class S(Strategy):
760790
def init(self):

0 commit comments

Comments
 (0)