Skip to content

Commit d597f90

Browse files
add searchsorted backend method
1 parent 4087817 commit d597f90

File tree

8 files changed

+46
-2
lines changed

8 files changed

+46
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
- Add circuit `from_qsim_file` method to load Google random circuit structure
2020

21+
- Add `searchsorted` method for backend
22+
2123
### Changed
2224

2325
- The inner mechanism for `sample_expectation_ps` is changed to sample representation from count representation for a fast speed

tensorcircuit/backends/abstract_backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,26 @@ def solve(self: Any, A: Tensor, b: Tensor, **kws: Any) -> Tensor:
746746
"Backend '{}' has not implemented `solve`.".format(self.name)
747747
)
748748

749+
def searchsorted(self: Any, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
750+
"""
751+
Find indices where elements should be inserted to maintain order.
752+
753+
:param a: input array sorted in ascending order
754+
:type a: Tensor
755+
:param v: value to inserted
756+
:type v: Tensor
757+
:param side: If ‘left’, the index of the first suitable location found is given.
758+
If ‘right’, return the last such index.
759+
If there is no suitable index, return either 0 or N (where N is the length of a),
760+
defaults to "left"
761+
:type side: str, optional
762+
:return: Array of insertion points with the same shape as v, or an integer if v is a scalar.
763+
:rtype: Tensor
764+
"""
765+
raise NotImplementedError(
766+
"Backend '{}' has not implemented `searchsorted`.".format(self.name)
767+
)
768+
749769
def tree_map(self: Any, f: Callable[..., Any], *pytrees: Any) -> Any:
750770
"""
751771
Return the new tree map with multiple arg function ``f`` through pytrees.

tensorcircuit/backends/jax_backend.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,13 @@ def is_tensor(self, a: Any) -> bool:
389389
def solve(self, A: Tensor, b: Tensor, assume_a: str = "gen") -> Tensor: # type: ignore
390390
return jsp.linalg.solve(A, b, assume_a)
391391

392+
def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
393+
if not self.is_tensor(a):
394+
a = self.convert_to_tensor(a)
395+
if not self.is_tensor(v):
396+
v = self.convert_to_tensor(v)
397+
return jnp.searchsorted(a, v, side)
398+
392399
def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
393400
return libjax.tree_map(f, *pytrees)
394401

tensorcircuit/backends/numpy_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,9 @@ def solve(self, A: Tensor, b: Tensor, assume_a: str = "gen") -> Tensor: # type:
223223
# https://stackoverflow.com/questions/44672029/difference-between-numpy-linalg-solve-and-numpy-linalg-lu-solve/44710451
224224
return solve(A, b, assume_a)
225225

226+
def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
227+
return np.searchsorted(a, v, side=side) # type: ignore
228+
226229
def set_random_state(
227230
self, seed: Optional[int] = None, get_only: bool = False
228231
) -> Any:

tensorcircuit/backends/pytorch_backend.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,13 @@ def left_shift(self, x: Tensor, y: Tensor) -> Tensor:
389389
def solve(self, A: Tensor, b: Tensor, **kws: Any) -> Tensor:
390390
return torchlib.linalg.solve(A, b)
391391

392+
def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
393+
if not self.is_tensor(a):
394+
a = self.convert_to_tensor(a)
395+
if not self.is_tensor(v):
396+
v = self.convert_to_tensor(v)
397+
return torchlib.searchsorted(a, v, side=side)
398+
392399
def reverse(self, a: Tensor) -> Tensor:
393400
return torchlib.flip(a, dims=(-1,))
394401

tensorcircuit/backends/tensorflow_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,9 @@ def solve(self, A: Tensor, b: Tensor, **kws: Any) -> Tensor:
434434
return self.reshape(x, x.shape[:-1])
435435
return x
436436

437+
def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
438+
return tf.searchsorted(a, v, side)
439+
437440
def from_dlpack(self, a: Any) -> Tensor:
438441
return tf.experimental.dlpack.from_dlpack(a)
439442

tensorcircuit/experimental.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ def parameter_shift_grad(
226226
:return: the grad function
227227
:rtype: Callable[..., Tensor]
228228
"""
229-
# TODO(@refraction-ray): finite shot sample_expectation_ps not supported well for now
230229
if jit is True:
231230
f = backend.jit(f)
232231

@@ -283,7 +282,6 @@ def parameter_shift_grad_v2(
283282
:return: the grad function
284283
:rtype: Callable[..., Tensor]
285284
"""
286-
# TODO(@refraction-ray): finite shot sample_expectation_ps not supported well for now
287285
if jit is True:
288286
f = backend.jit(f)
289287

tests/test_backends.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,10 @@ def test_backend_methods_2(backend):
279279
np.array([1, 3]),
280280
)
281281
assert tc.backend.dtype(tc.backend.ones([])) == "complex64"
282+
edges = [-1, 3.3, 9.1, 10.0]
283+
values = tc.backend.convert_to_tensor(np.array([0.0, 4.1, 12.0], dtype=np.float32))
284+
r = tc.backend.numpy(tc.backend.searchsorted(edges, values))
285+
np.testing.assert_allclose(r, np.array([1, 2, 4]))
282286

283287

284288
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])

0 commit comments

Comments
 (0)