Skip to content

Commit 7d57dcc

Browse files
authored
feat: support all_users & all_sender_channels for segment (#164)
* feat: support all_users & all_sender_channels for segment * fix base type for updatable data * remove duplication
1 parent c6dc110 commit 7d57dcc

File tree

7 files changed

+58
-10
lines changed

7 files changed

+58
-10
lines changed

stream_chat/async_chat/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
QuerySegmentTargetsOptions,
2727
SegmentData,
2828
SegmentType,
29+
SegmentUpdatableFields,
2930
)
3031

3132
if sys.version_info >= (3, 8):
@@ -591,7 +592,7 @@ async def query_segments(
591592
return await self.post("segments/query", data=payload)
592593

593594
async def update_segment(
594-
self, segment_id: str, data: SegmentData
595+
self, segment_id: str, data: SegmentUpdatableFields
595596
) -> StreamResponse:
596597
return await self.put(f"segments/{segment_id}", data=data)
597598

stream_chat/async_chat/segment.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
from stream_chat.base.segment import SegmentInterface
44
from stream_chat.types.base import SortParam
5-
from stream_chat.types.segment import QuerySegmentTargetsOptions, SegmentData
5+
from stream_chat.types.segment import (
6+
QuerySegmentTargetsOptions,
7+
SegmentData,
8+
SegmentUpdatableFields,
9+
)
610
from stream_chat.types.stream_response import StreamResponse
711

812

@@ -24,24 +28,29 @@ async def create(
2428
return state
2529

2630
async def get(self) -> StreamResponse:
31+
super().verify_segment_id()
2732
return await self.client.get_segment(segment_id=self.segment_id) # type: ignore
2833

29-
async def update(self, data: SegmentData) -> StreamResponse:
34+
async def update(self, data: SegmentUpdatableFields) -> StreamResponse:
35+
super().verify_segment_id()
3036
return await self.client.update_segment( # type: ignore
3137
segment_id=self.segment_id, data=data
3238
)
3339

3440
async def delete(self) -> StreamResponse:
41+
super().verify_segment_id()
3542
return await self.client.delete_segment( # type: ignore
3643
segment_id=self.segment_id
3744
)
3845

3946
async def target_exists(self, target_id: str) -> StreamResponse:
47+
super().verify_segment_id()
4048
return await self.client.segment_target_exists( # type: ignore
4149
segment_id=self.segment_id, target_id=target_id
4250
)
4351

4452
async def add_targets(self, target_ids: list) -> StreamResponse:
53+
super().verify_segment_id()
4554
return await self.client.add_segment_targets( # type: ignore
4655
segment_id=self.segment_id, target_ids=target_ids
4756
)
@@ -52,6 +61,7 @@ async def query_targets(
5261
sort: Optional[List[SortParam]] = None,
5362
options: Optional[QuerySegmentTargetsOptions] = None,
5463
) -> StreamResponse:
64+
super().verify_segment_id()
5565
return await self.client.query_segment_targets( # type: ignore
5666
segment_id=self.segment_id,
5767
filter_conditions=filter_conditions,
@@ -60,6 +70,7 @@ async def query_targets(
6070
)
6171

6272
async def remove_targets(self, target_ids: list) -> StreamResponse:
73+
super().verify_segment_id()
6374
return await self.client.remove_segment_targets( # type: ignore
6475
segment_id=self.segment_id, target_ids=target_ids
6576
)

stream_chat/base/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
QuerySegmentTargetsOptions,
1515
SegmentData,
1616
SegmentType,
17+
SegmentUpdatableFields,
1718
)
1819

1920
if sys.version_info >= (3, 8):
@@ -982,7 +983,7 @@ def query_segments(
982983

983984
@abc.abstractmethod
984985
def update_segment(
985-
self, segment_id: str, data: SegmentData
986+
self, segment_id: str, data: SegmentUpdatableFields
986987
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
987988
"""
988989
Update a segment by id

stream_chat/base/segment.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
QuerySegmentTargetsOptions,
88
SegmentData,
99
SegmentType,
10+
SegmentUpdatableFields,
1011
)
1112
from stream_chat.types.stream_response import StreamResponse
1213

@@ -36,7 +37,7 @@ def get(self) -> Union[StreamResponse, Awaitable[StreamResponse]]:
3637

3738
@abc.abstractmethod
3839
def update(
39-
self, data: SegmentData
40+
self, data: SegmentUpdatableFields
4041
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
4142
pass
4243

@@ -70,3 +71,10 @@ def remove_targets(
7071
self, target_ids: List[str]
7172
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
7273
pass
74+
75+
def verify_segment_id(self) -> None:
76+
if not self.segment_id:
77+
raise ValueError(
78+
"Segment id is missing. Either create the segment using segment.create() "
79+
"or set the id during instantiation - segment = Segment(segment_id=segment_id)"
80+
)

stream_chat/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
QuerySegmentTargetsOptions,
1616
SegmentData,
1717
SegmentType,
18+
SegmentUpdatableFields,
1819
)
1920

2021
if sys.version_info >= (3, 8):
@@ -569,7 +570,9 @@ def query_segments(
569570
payload.update(cast(dict, options))
570571
return self.post("segments/query", data=payload)
571572

572-
def update_segment(self, segment_id: str, data: SegmentData) -> StreamResponse:
573+
def update_segment(
574+
self, segment_id: str, data: SegmentUpdatableFields
575+
) -> StreamResponse:
573576
return self.put(f"segments/{segment_id}", data=data)
574577

575578
def delete_segment(self, segment_id: str) -> StreamResponse:

stream_chat/segment.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
from stream_chat.base.segment import SegmentInterface
44
from stream_chat.types.base import SortParam
5-
from stream_chat.types.segment import QuerySegmentTargetsOptions, SegmentData
5+
from stream_chat.types.segment import (
6+
QuerySegmentTargetsOptions,
7+
SegmentData,
8+
SegmentUpdatableFields,
9+
)
610
from stream_chat.types.stream_response import StreamResponse
711

812

@@ -24,22 +28,27 @@ def create(
2428
return state # type: ignore
2529

2630
def get(self) -> StreamResponse:
31+
super().verify_segment_id()
2732
return self.client.get_segment(segment_id=self.segment_id) # type: ignore
2833

29-
def update(self, data: SegmentData) -> StreamResponse:
34+
def update(self, data: SegmentUpdatableFields) -> StreamResponse:
35+
super().verify_segment_id()
3036
return self.client.update_segment( # type: ignore
3137
segment_id=self.segment_id, data=data
3238
)
3339

3440
def delete(self) -> StreamResponse:
41+
super().verify_segment_id()
3542
return self.client.delete_segment(segment_id=self.segment_id) # type: ignore
3643

3744
def target_exists(self, target_id: str) -> StreamResponse:
45+
super().verify_segment_id()
3846
return self.client.segment_target_exists( # type: ignore
3947
segment_id=self.segment_id, target_id=target_id
4048
)
4149

4250
def add_targets(self, target_ids: list) -> StreamResponse:
51+
super().verify_segment_id()
4352
return self.client.add_segment_targets( # type: ignore
4453
segment_id=self.segment_id, target_ids=target_ids
4554
)
@@ -50,6 +59,7 @@ def query_targets(
5059
sort: Optional[List[SortParam]] = None,
5160
options: Optional[QuerySegmentTargetsOptions] = None,
5261
) -> StreamResponse:
62+
super().verify_segment_id()
5363
return self.client.query_segment_targets( # type: ignore
5464
segment_id=self.segment_id,
5565
sort=sort,
@@ -58,6 +68,7 @@ def query_targets(
5868
)
5969

6070
def remove_targets(self, target_ids: list) -> StreamResponse:
71+
super().verify_segment_id()
6172
return self.client.remove_segment_targets( # type: ignore
6273
segment_id=self.segment_id, target_ids=target_ids
6374
)

stream_chat/types/segment.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ class SegmentType(Enum):
2323
USER = "user"
2424

2525

26-
class SegmentData(TypedDict, total=False):
26+
class SegmentUpdatableFields(TypedDict, total=False):
2727
"""
28-
Represents the data structure for a segment.
28+
Represents the updatable data structure for a segment.
2929
3030
Parameters:
3131
name: The name of the segment.
@@ -38,6 +38,19 @@ class SegmentData(TypedDict, total=False):
3838
filter: Optional[Dict]
3939

4040

41+
class SegmentData(SegmentUpdatableFields, total=False):
42+
"""
43+
Represents the data structure for a segment.
44+
45+
Parameters:
46+
all_users: Whether to target all users.
47+
all_sender_channels: Whether to target all sender channels.
48+
"""
49+
50+
all_users: Optional[bool]
51+
all_sender_channels: Optional[bool]
52+
53+
4154
class QuerySegmentsOptions(Pager, total=False):
4255
pass
4356

0 commit comments

Comments
 (0)