Skip to content

Commit 14992e2

Browse files
authored
Programmable video source (whitphx#1349)
* Create VideoSourceTrack * Remove queue * Change the frame buffer to be an instance of av.VideoFrame * Fix VideoSourceTrack to be callback-based * Fix the factory function * Fix * Stop the source tracks automatically * Stop the relayed source tracks automatically * Stop the source track in the sample * Pass pts and time_base to the callback * Make fps configurable * Fix * Export VideoSourceCallback * Fix * Fix type annotations * Fix * Fix
1 parent 7d17342 commit 14992e2

File tree

5 files changed

+194
-6
lines changed

5 files changed

+194
-6
lines changed

Diff for: pages/14_programmable_source.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import fractions
2+
import time
3+
4+
import av
5+
import cv2
6+
import numpy as np
7+
import streamlit as st
8+
from streamlit_webrtc import WebRtcMode, create_video_source_track, webrtc_streamer
9+
10+
thickness = st.slider("thickness", 1, 10, 3, 1)
11+
12+
13+
def video_source_callback(pts: int, time_base: fractions.Fraction) -> av.VideoFrame:
14+
pts_sec = pts * time_base
15+
16+
buffer = np.zeros((480, 640, 3), dtype=np.uint8)
17+
buffer = cv2.putText(
18+
buffer,
19+
text=f"time: {time.time():.2f}",
20+
org=(0, 32),
21+
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
22+
fontScale=1.0,
23+
color=(255, 255, 0),
24+
thickness=thickness,
25+
lineType=cv2.LINE_4,
26+
)
27+
buffer = cv2.putText(
28+
buffer,
29+
text=f"pts: {pts} ({float(pts_sec):.2f} sec)",
30+
org=(0, 64),
31+
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
32+
fontScale=1.0,
33+
color=(255, 255, 0),
34+
thickness=thickness,
35+
lineType=cv2.LINE_4,
36+
)
37+
return av.VideoFrame.from_ndarray(buffer, format="bgr24")
38+
39+
40+
fps = st.slider("fps", 1, 30, 30, 1)
41+
42+
43+
video_source_track = create_video_source_track(
44+
video_source_callback, key="video_source_track", fps=fps
45+
)
46+
47+
48+
def on_change():
49+
ctx = st.session_state["player"]
50+
stopped = not ctx.state.playing and not ctx.state.signalling
51+
if stopped:
52+
video_source_track.stop() # Manually stop the track.
53+
54+
55+
webrtc_streamer(
56+
key="player",
57+
mode=WebRtcMode.RECVONLY,
58+
source_video_track=video_source_track,
59+
media_stream_constraints={"video": True, "audio": False},
60+
on_change=on_change,
61+
)

Diff for: streamlit_webrtc/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
Translations,
2424
VideoHTMLAttributes,
2525
)
26-
from .factory import create_mix_track, create_process_track
26+
from .factory import create_mix_track, create_process_track, create_video_source_track
2727
from .mix import MixerCallback
28+
from .source import VideoSourceCallback, VideoSourceTrack
2829
from .webrtc import (
2930
AudioProcessorBase,
3031
AudioProcessorFactory,
@@ -64,6 +65,9 @@
6465
"VideoProcessorFactory",
6566
"VideoTransformerBase", # XXX: Deprecated
6667
"VideoReceiver",
68+
"VideoSourceTrack",
69+
"VideoSourceCallback",
70+
"create_video_source_track",
6771
"WebRtcMode",
6872
"WebRtcWorker",
6973
"MediaStreamConstraints",

Diff for: streamlit_webrtc/factory.py

+25
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
VideoProcessTrack,
3232
)
3333
from .relay import get_global_relay
34+
from .source import VideoSourceCallback, VideoSourceTrack
3435

3536
_PROCESSOR_TRACK_CACHE_KEY_PREFIX = "__PROCESSOR_TRACK_CACHE__"
3637

@@ -191,3 +192,27 @@ def create_mix_track(
191192
)
192193
st.session_state[cache_key] = mixer_track
193194
return mixer_track
195+
196+
197+
_VIDEO_SOURCE_TRACK_CACHE_KEY_PREFIX = "__VIDEO_SOURCE_TRACK_CACHE__"
198+
199+
200+
def create_video_source_track(
201+
callback: VideoSourceCallback,
202+
key: str,
203+
fps=30,
204+
) -> VideoSourceTrack:
205+
cache_key = _VIDEO_SOURCE_TRACK_CACHE_KEY_PREFIX + key
206+
if (
207+
cache_key in st.session_state
208+
and isinstance(st.session_state[cache_key], VideoSourceTrack)
209+
and st.session_state[cache_key].kind == "video"
210+
and st.session_state[cache_key].readyState == "live"
211+
):
212+
video_source_track: VideoSourceTrack = st.session_state[cache_key]
213+
video_source_track._callback = callback
214+
video_source_track._fps = fps
215+
else:
216+
video_source_track = VideoSourceTrack(callback=callback, fps=fps)
217+
st.session_state[cache_key] = video_source_track
218+
return video_source_track

Diff for: streamlit_webrtc/source.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import asyncio
2+
import fractions
3+
import logging
4+
import time
5+
from typing import Callable, Optional, Union
6+
7+
import av
8+
from aiortc import MediaStreamTrack
9+
from aiortc.mediastreams import MediaStreamError
10+
11+
logger = logging.getLogger(__name__)
12+
13+
# Copied from https://github.com/aiortc/aiortc/blob/main/src/aiortc/mediastreams.py
14+
AUDIO_PTIME = 0.020 # 20ms audio packetization
15+
VIDEO_CLOCK_RATE = 90000
16+
VIDEO_TIME_BASE = fractions.Fraction(1, VIDEO_CLOCK_RATE)
17+
18+
19+
# Ref: VideoStreamTrack and AudioStreamTrack in
20+
# https://github.com/aiortc/aiortc/blob/main/src/aiortc/mediastreams.py
21+
22+
23+
VideoSourceCallback = Callable[
24+
[int, fractions.Fraction], av.VideoFrame
25+
] # (pts, time_base) -> frame
26+
27+
28+
class VideoSourceTrack(MediaStreamTrack):
29+
def __init__(self, callback: VideoSourceCallback, fps: Union[int, float]) -> None:
30+
super().__init__()
31+
self.kind = "video"
32+
self._callback = callback
33+
self._fps = fps
34+
self._started_at: Optional[float] = None
35+
self._pts: Optional[int] = None
36+
37+
async def recv(self) -> av.frame.Frame:
38+
if self.readyState != "live":
39+
raise MediaStreamError
40+
41+
if self._started_at is None or self._pts is None:
42+
self._started_at = time.monotonic()
43+
self._pts = 0
44+
45+
frame = self._call_callback(self._pts, VIDEO_TIME_BASE)
46+
else:
47+
self._pts += int(VIDEO_CLOCK_RATE / self._fps)
48+
49+
frame = self._call_callback(self._pts, VIDEO_TIME_BASE)
50+
51+
wait = self._started_at + (self._pts / VIDEO_CLOCK_RATE) - time.monotonic()
52+
if wait < 0:
53+
logger.warning(
54+
"VideoSourceCallbackTrack: Video frame callback is too slow."
55+
)
56+
wait = 0
57+
await asyncio.sleep(wait)
58+
59+
return frame
60+
61+
def _call_callback(self, pts: int, time_base: fractions.Fraction) -> av.VideoFrame:
62+
try:
63+
frame = self._callback(pts, time_base)
64+
except Exception as exc:
65+
logger.error(
66+
"VideoSourceCallbackTrack: Video frame callback raised an exception: %s", # noqa: E501
67+
exc,
68+
)
69+
raise
70+
71+
frame.pts = pts
72+
frame.time_base = time_base
73+
return frame

Diff for: streamlit_webrtc/webrtc.py

+30-5
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,8 @@ class WebRtcWorker(Generic[VideoProcessorT, AudioProcessorT]):
308308
_output_video_track: Optional[MediaStreamTrack]
309309
_output_audio_track: Optional[MediaStreamTrack]
310310
_player: Optional[MediaPlayer]
311+
_relayed_source_video_track: Optional[MediaRelay]
312+
_relayed_source_audio_track: Optional[MediaRelay]
311313

312314
@property
313315
def video_processor(
@@ -404,6 +406,8 @@ def __init__(
404406
self._output_video_track = None
405407
self._output_audio_track = None
406408
self._player = None
409+
self._relayed_source_video_track = None
410+
self._relayed_source_audio_track = None
407411

408412
self._session_shutdown_observer = SessionShutdownObserver(self.stop)
409413

@@ -504,14 +508,20 @@ def on_track_created(track_type: TrackType, track: MediaStreamTrack):
504508
player = self.player_factory()
505509
self._player = player
506510
if player.audio:
507-
source_audio_track = relay.subscribe(player.audio)
511+
source_audio_track = player.audio
508512
if player.video:
509-
source_video_track = relay.subscribe(player.video)
513+
source_video_track = player.video
510514
else:
511-
if self.source_video_track:
512-
source_video_track = relay.subscribe(self.source_video_track)
513515
if self.source_audio_track:
514-
source_audio_track = relay.subscribe(self.source_audio_track)
516+
self._relayed_source_audio_track = relay.subscribe(
517+
self.source_audio_track
518+
)
519+
source_audio_track = self._relayed_source_audio_track
520+
if self.source_video_track:
521+
self._relayed_source_video_track = relay.subscribe(
522+
self.source_video_track
523+
)
524+
source_video_track = self._relayed_source_video_track
515525

516526
@self.pc.on("iceconnectionstatechange")
517527
async def on_iceconnectionstatechange():
@@ -655,6 +665,21 @@ def _unset_processors(self):
655665
self._player.audio.stop()
656666
self._player = None
657667

668+
# Same as above,
669+
# the source tracks are not automatically stopped when the WebRTC.
670+
# Only the relayed tracks are stopped here because
671+
# the upstream tracks may still be used by other consumers.
672+
if self._relayed_source_audio_track:
673+
logger.debug("Stopping the relayed source audio track")
674+
self._relayed_source_audio_track.stop()
675+
self.source_audio_track = None
676+
self._relayed_source_audio_track = None
677+
if self._relayed_source_video_track:
678+
logger.debug("Stopping the relayed source video track")
679+
self._relayed_source_video_track.stop()
680+
self.source_video_track = None
681+
self._relayed_source_video_track = None
682+
658683
def stop(self, timeout: Union[float, None] = 1.0):
659684
self._unset_processors()
660685
if self._process_offer_thread:

0 commit comments

Comments
 (0)