|
5 | 5 |
|
6 | 6 | import math
|
7 | 7 | import numbers
|
8 |
| -from typing import Iterable, Optional, Sequence |
| 8 | +from typing import Iterable, List, Optional, Sequence |
9 | 9 |
|
10 | 10 | import numpy as np
|
11 | 11 |
|
@@ -246,18 +246,13 @@ def create_axis(self, range_terms, axis_config, length):
|
246 | 246 | new_config["length"] = length
|
247 | 247 | axis = NumberLine(range_terms, **new_config)
|
248 | 248 |
|
249 |
| - # without the if/elif, graph does not exist when min > 0 or max < 0 |
| 249 | + # without the call to origin_shift, graph does not exist when min > 0 or max < 0 |
250 | 250 | # shifts the axis so that 0 is centered
|
251 |
| - if range_terms[0] > 0: |
252 |
| - axis.shift(-axis.number_to_point(range_terms[0])) |
253 |
| - elif range_terms[1] < 0: |
254 |
| - axis.shift(-axis.number_to_point(range_terms[1])) |
255 |
| - else: |
256 |
| - axis.shift(-axis.number_to_point(0)) |
| 251 | + axis.shift(-axis.number_to_point(self.origin_shift(range_terms))) |
257 | 252 | return axis
|
258 | 253 |
|
259 | 254 | def coords_to_point(self, *coords):
|
260 |
| - origin = self.x_axis.number_to_point(0) |
| 255 | + origin = self.x_axis.number_to_point(self.origin_shift(self.x_range)) |
261 | 256 | result = np.array(origin)
|
262 | 257 | for axis, coord in zip(self.get_axes(), coords):
|
263 | 258 | result += axis.number_to_point(coord) - origin
|
@@ -367,6 +362,22 @@ def construct(self):
|
367 | 362 |
|
368 | 363 | return line_graph
|
369 | 364 |
|
| 365 | + @staticmethod |
| 366 | + def origin_shift(axis_range: List[float]) -> float: |
| 367 | + """Determines how to shift graph mobjects to compensate when 0 is not on the axis. |
| 368 | +
|
| 369 | + Parameters |
| 370 | + ---------- |
| 371 | + axis_range |
| 372 | + The range of the axis : ``(x_min, x_max, x_step)``. |
| 373 | + """ |
| 374 | + if axis_range[0] > 0: |
| 375 | + return axis_range[0] |
| 376 | + if axis_range[1] < 0: |
| 377 | + return axis_range[1] |
| 378 | + else: |
| 379 | + return 0 |
| 380 | + |
370 | 381 |
|
371 | 382 | class ThreeDAxes(Axes):
|
372 | 383 | """A 3-dimensional set of axes.
|
@@ -444,7 +455,7 @@ def __init__(
|
444 | 455 | z_axis = self.create_axis(self.z_range, self.z_axis_config, self.z_length)
|
445 | 456 | z_axis.rotate_about_zero(-PI / 2, UP)
|
446 | 457 | z_axis.rotate_about_zero(angle_of_vector(self.z_normal))
|
447 |
| - z_axis.shift(self.x_axis.n2p(0)) |
| 458 | + z_axis.shift(self.x_axis.number_to_point(self.origin_shift(x_range))) |
448 | 459 |
|
449 | 460 | self.axes.add(z_axis)
|
450 | 461 | self.add(z_axis)
|
|
0 commit comments