@@ -1067,6 +1067,41 @@ def stateful_randc(
1067
1067
"Backend '{}' has not implemented `stateful_randc`." .format (self .name )
1068
1068
)
1069
1069
1070
+ def probability_sample (
1071
+ self : Any , shots : int , p : Tensor , status : Optional [Tensor ] = None , g : Any = None
1072
+ ) -> Tensor :
1073
+ """
1074
+ Drawn ``shots`` samples from probability distribution p, given the external randomness
1075
+ determined by uniform distributed ``status`` tensor or backend random generator ``g``.
1076
+ This method is similar with ``stateful_randc``, but it supports ``status`` beyond ``g``,
1077
+ which is convenient when jit or vmap
1078
+
1079
+ :param shots: Number of samples to draw with replacement
1080
+ :type shots: int
1081
+ :param p: prbability vector
1082
+ :type p: Tensor
1083
+ :param status: external randomness as a tensor with each element drawn uniformly from [0, 1],
1084
+ defaults to None
1085
+ :type status: Optional[Tensor], optional
1086
+ :param g: backend random genrator, defaults to None
1087
+ :type g: Any, optional
1088
+ :return: The drawn sample as an int tensor
1089
+ :rtype: Tensor
1090
+ """
1091
+ if status is not None :
1092
+ status = self .convert_to_tensor (status )
1093
+ elif g is not None :
1094
+ status = self .stateful_randu (g , shape = [shots ])
1095
+ else :
1096
+ status = self .implicit_randu (shape = [shots ])
1097
+ p = p / self .sum (p )
1098
+ p_cuml = self .cumsum (p )
1099
+ r = p_cuml [- 1 ] * (1 - self .cast (status , p .dtype ))
1100
+ ind = self .searchsorted (p_cuml , r )
1101
+ a = self .arange (shots )
1102
+ res = self .gather1d (a , ind )
1103
+ return res
1104
+
1070
1105
def gather1d (self : Any , operand : Tensor , indices : Tensor ) -> Tensor :
1071
1106
"""
1072
1107
Return ``operand[indices]``, both ``operand`` and ``indices`` are rank-1 tensor.
0 commit comments