
modify encode_video() / encode_audio() API design #58

Open · wants to merge 7 commits into base: main
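This PR changes the predictor API from stateful to value-passing: `encode_video()` / `encode_audio()` now return a feature dictionary instead of caching features on the model instance, and `predict()` takes that dictionary as an explicit second argument. A minimal before/after sketch of the usage, based on the README diff below (the checkpoint and media paths are the example assets already referenced there):

```python
import torch
from lighthouse.models import CGDETRPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CGDETRPredictor('/path/to/weight.ckpt', device=device,
                        feature_name='clip_slowfast', slowfast_path='SLOWFAST_8x8_R50.pkl')
query = 'A man is speaking in front of the camera'

# Old API (before this PR): features were stored on the model instance.
# model.encode_video('api_example/RoripwjYFp8_60.0_210.0.mp4')
# prediction = model.predict(query)

# New API (this PR): encode_video() returns a dict of features
# ('video_feats', 'video_mask', 'audio_feats') that is passed to predict().
video = model.encode_video('api_example/RoripwjYFp8_60.0_210.0.mp4')
prediction = model.predict(query, video)
print(prediction)
```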
23 changes: 19 additions & 4 deletions README.md
@@ -11,6 +11,7 @@ It supports seven models, four features (video and audio features), and six data
Furthermore, Lighthouse supports [audio moment retrieval](https://h-munakata.github.io/Language-based-Audio-Moment-Retrieval/), a task to identify relevant moments from an audio input based on a given text query.

## News
- [2025/05/20] [Version 1.1]() has been released. It includes API changes and an AMR Gradio demo.
- [2024/12/24] Our work ["Language-based audio moment retrieval"](https://arxiv.org/abs/2409.15672) has been accepted at ICASSP 2025.
- [2024/10/22] [Version 1.0](https://github.com/line/lighthouse/releases/tag/v1.0) has been released.
- [2024/10/6] Our paper has been accepted at EMNLP2024, system demonstration track.
@@ -43,14 +44,14 @@ device = "cuda" if torch.cuda.is_available() else "cpu"

# slowfast_path is necessary if you use clip_slowfast features
query = 'A man is speaking in front of the camera'
model = CGDETRPredictor('results/cg_detr/qvhighlight/clip_slowfast/best.ckpt', device=device,
model = CGDETRPredictor('/path/to/weight.ckpt', device=device,
feature_name='clip_slowfast', slowfast_path='SLOWFAST_8x8_R50.pkl')

# encode video features
model.encode_video('api_example/RoripwjYFp8_60.0_210.0.mp4')
video = model.encode_video('api_example/RoripwjYFp8_60.0_210.0.mp4')

# moment retrieval & highlight detection
prediction = model.predict(query)
prediction = model.predict(query, video)
print(prediction)
"""
pred_relevant_windows: [[start, end, score], ...,]
@@ -67,7 +68,20 @@ pred_saliency_scores: [score, ...]
...]}
"""
```
Run `python api_example/demo.py` to reproduce the results. It automatically downloads pre-trained weights for CG-DETR (CLIP backbone).
Lighthouse also supports the AMR inference API:
```python
import torch
from lighthouse.models import QDDETRPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = QDDETRPredictor('/path/to/weight.ckpt', device=device, feature_name='clap')

audio = model.encode_audio('api_example/1a-ODBWMUAE.wav')
query = 'Water cascades down from a waterfall.'
prediction = model.predict(query, audio)
print(prediction)
```
Run `python api_example/demo.py` (MR-HD) or `python api_example/amr_demo.py` (AMR) to reproduce the results. Each script automatically downloads the pre-trained weights it needs.
If you want to use other models, download [pre-trained weights](https://drive.google.com/file/d/1jxs_bvwttXTF9Lk3aKLohkqfYOonLyrO/view?usp=sharing).
When using `clip_slowfast` features, it is necessary to download [slowfast pre-trained weights](https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/SLOWFAST_8x8_R50.pkl).
When using `clip_slowfast_pann` features, in addition to the slowfast weight, download [panns weights](https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth).
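For reference, a hedged sketch of wiring those extra weights into a predictor under the new API; it assumes the files keep their default download names (`SLOWFAST_8x8_R50.pkl`, `Cnn14_mAP=0.431.pth`) and sit in the working directory:

```python
import torch
from lighthouse.models import CGDETRPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"

# clip_slowfast_pann features need both the SlowFast and PANNs checkpoints.
# The weight filenames below are assumptions based on the default download names.
model = CGDETRPredictor('/path/to/weight.ckpt', device=device,
                        feature_name='clip_slowfast_pann',
                        slowfast_path='SLOWFAST_8x8_R50.pkl',
                        pann_path='Cnn14_mAP=0.431.pth')

video = model.encode_video('api_example/RoripwjYFp8_60.0_210.0.mp4')
prediction = model.predict('A man is speaking in front of the camera', video)
print(prediction)
```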
@@ -95,6 +109,7 @@ Moment retrieval & highlight detection
- [x] : [UVCOM (Xiao et al. CVPR24)](https://arxiv.org/abs/2311.16464)
- [x] : [TR-DETR (Sun et al. AAAI24)](https://arxiv.org/abs/2401.02309)
- [x] : [TaskWeave (Jin et al. CVPR24)](https://arxiv.org/abs/2404.09263)
- [ ] : [R2-Tuning (Liu et al. ECCV24)](https://arxiv.org/abs/2404.00801)

### Datasets
Moment retrieval & highlight detection
4 changes: 2 additions & 2 deletions api_example/amr_demo.py
@@ -33,9 +33,9 @@ def load_weights(weight_dir: str) -> None:
model: QDDETRPredictor = QDDETRPredictor(weight_path, device=device, feature_name='clap')

# encode audio features
model.encode_audio('api_example/1a-ODBWMUAE.wav')
audio = model.encode_audio('api_example/1a-ODBWMUAE.wav')

# moment retrieval
query: str = 'Water cascades down from a waterfall.'
prediction: Optional[Dict[str, List[float]]] = model.predict(query)
prediction: Optional[Dict[str, List[float]]] = model.predict(query, audio)
print(prediction)
8 changes: 4 additions & 4 deletions api_example/demo.py
@@ -21,8 +21,8 @@
from typing import Dict, List, Optional

def load_weights(weight_dir: str) -> None:
if not os.path.exists(os.path.join(weight_dir, 'clip_slowfast_pann_cg_detr_qvhighlight.ckpt')):
command = 'wget -P gradio_demo/weights/ https://zenodo.org/records/13960580/files/clip_slowfast_pann_cg_detr_qvhighlight.ckpt'
if not os.path.exists(os.path.join(weight_dir, 'clip_slowfast_cg_detr_qvhighlight.ckpt')):
command = 'wget -P gradio_demo/weights/ https://zenodo.org/records/13960580/files/clip_slowfast_cg_detr_qvhighlight.ckpt'
subprocess.run(command, shell=True)

if not os.path.exists('SLOWFAST_8x8_R50.pkl'):
@@ -40,9 +40,9 @@ def load_weights(weight_dir: str) -> None:
slowfast_path='SLOWFAST_8x8_R50.pkl', pann_path=None)

# encode video features
model.encode_video('api_example/RoripwjYFp8_60.0_210.0.mp4')
video = model.encode_video('api_example/RoripwjYFp8_60.0_210.0.mp4')

# moment retrieval & highlight detection
query: str = 'A woman wearing a glass is speaking in front of the camera'
prediction: Optional[Dict[str, List[float]]] = model.predict(query)
prediction: Optional[Dict[str, List[float]]] = model.predict(query, video)
print(prediction)
12 changes: 8 additions & 4 deletions gradio_demo/amr_demo.py
@@ -55,17 +55,20 @@ def flatten(array2d):
"""
load_pretrained_weights()
model = QDDETRPredictor('gradio_demo/weights/clap_qd_detr_clotho-moment.ckpt', device=device, feature_name='clap')
loaded_audio = None

"""
Gradio functions
"""
def audio_upload(audio):
global loaded_audio
if audio is None:
model.audio_feats = None
loaded_audio = None
yield gr.update(value="Removed the audio", visible=True)
else:
yield gr.update(value="Processing the audio. Wait for a minute...", visible=True)
model.encode_audio(audio)
audio_feats = model.encode_audio(audio)
loaded_audio = audio_feats
yield gr.update(value="Finished audio processing!", visible=True)

def model_load(radio):
@@ -85,10 +88,11 @@ def model_load(radio):
yield gr.update(value="Model loaded: {}".format(radio), visible=True)

def predict(textbox, line, gallery):
prediction = model.predict(textbox)
if prediction is None:
global loaded_audio
if loaded_audio is None:
raise gr.Error('Upload the audio before pushing the `Retrieve moment` button.')
else:
prediction = model.predict(textbox, loaded_audio)
mr_results = prediction['pred_relevant_windows']

buttons = []
28 changes: 19 additions & 9 deletions gradio_demo/demo.py
@@ -66,6 +66,8 @@ def flatten(array2d):
load_pretrained_weights()
model = CGDETRPredictor('gradio_demo/weights/clip_cg_detr_qvhighlight.ckpt', device=device,
feature_name='clip', slowfast_path=None, pann_path=None)
loaded_video = None
loaded_video_path = None

js_codes = ["""() => {{
let moment_text = document.getElementById('result_{}').textContent;
@@ -79,17 +81,19 @@ def flatten(array2d):
Gradio functions
"""
def video_upload(video):
global loaded_video, loaded_video_path
if video is None:
model.video_feats = None
model.video_mask = None
model.video_path = None
loaded_video = None
loaded_video_path = video
yield gr.update(value="Removed the video", visible=True)
else:
yield gr.update(value="Processing the video. Wait for a minute...", visible=True)
model.encode_video(video)
loaded_video = model.encode_video(video)
loaded_video_path = video
yield gr.update(value="Finished video processing!", visible=True)

def model_load(radio, video):
global loaded_video, loaded_video_path
if radio is not None:
loading_msg = "Loading new model. Wait for a minute..."
yield gr.update(value=loading_msg, visible=True), gr.update(value=loading_msg, visible=True)
@@ -121,15 +125,21 @@ def model_load(radio, video):
yield gr.update(value=load_finished_msg, visible=True), gr.update(value=encode_process_msg, visible=True)

if video is not None:
model.encode_video(video)
loaded_video = model.encode_video(video)
loaded_video_path = video
encode_finished_msg = "Finished video processing!"
yield gr.update(value=load_finished_msg, visible=True), gr.update(value=encode_finished_msg, visible=True)
else:
loaded_video = None
loaded_video_path = None

def predict(textbox, line, gallery):
prediction = model.predict(textbox)
if prediction is None:
raise gr.Error('Upload the video before pushing the `Retrieve moment & highlight detection` button.')
global loaded_video, loaded_video_path
if loaded_video is None:
raise gr.Error("Upload the video before pushing the `Retrieve moment & highlight detection` button.")
else:
prediction = model.predict(textbox, loaded_video)

mr_results = prediction['pred_relevant_windows']
hl_results = prediction['pred_saliency_scores']

@@ -154,7 +164,7 @@ def predict(textbox, line, gallery):
output_path = "gradio_demo/highlight_frames/highlight_{}.png".format(i)
(
ffmpeg
.input(model._video_path, ss=second)
.input(loaded_video_path, ss=second)
.output(output_path, vframes=1, qscale=2)
.global_args('-loglevel', 'quiet', '-y')
.run()
4 changes: 2 additions & 2 deletions lighthouse/common/tr_detr_transformer.py
@@ -138,7 +138,7 @@ def forward(self, src, mask, query_embed, pos_embed, saliency_proj1, video_lengt
# *Use highlight scores to suppress feature expressions in non-highlight clips
memory_local, saliency_scores = self.HD2MR(memory_local, saliency_proj1, src, video_length, mask_local, pos_embed_local)

tgt = torch.zeros(refpoint_embed.shape[0], bs, d).cuda()
tgt = torch.zeros(refpoint_embed.shape[0], bs, d).to(src.device)
hs, references = self.decoder(tgt, memory_local, memory_key_padding_mask=mask_local,
pos=pos_embed_local, refpoints_unsigmoid=refpoint_embed) # (#layers, #queries, batch_size, d)

@@ -880,4 +880,4 @@ def _get_activation_fn(activation):
return nn.PReLU()
if activation == "selu":
return F.selu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
78 changes: 47 additions & 31 deletions lighthouse/models.py
@@ -87,10 +87,6 @@ def __init__(

self._feature_name: str = feature_name
self._model_name: str = model_name
self._video_feats: Optional[torch.Tensor] = None
self._video_mask: Optional[torch.Tensor] = None
self._video_path: Optional[str] = None
self._audio_feats: Optional[torch.Tensor] = None

def _initialize_model(
self,
@@ -153,34 +149,42 @@ def _normalize_and_concat_with_timestamps(

def _is_predictable(
self,
) -> bool:
if (self._video_feats is None or self._video_mask is None or self._video_path is None) and self._feature_name != 'clap':
video: Dict[str, Optional[torch.Tensor]]) -> bool:

is_vfeat = 'video_feats' in video
is_afeat = 'audio_feats' in video

if not is_vfeat and self._feature_name != 'clap':
return False
if (self._feature_name == 'clip_slowfast_pann' or self._feature_name == 'clap') and self._audio_feats is None:

if (self._feature_name == 'clip_slowfast_pann' or self._feature_name == 'clap') and not is_afeat:
return False

return True

def _prepare_batch(
self,
query_feats: torch.Tensor,
query_mask: torch.Tensor) -> Dict[str, Optional[torch.Tensor]]:
query_mask: torch.Tensor,
video: Dict[str, Optional[torch.Tensor]]
) -> Dict[str, Optional[torch.Tensor]]:

if self._model_name == 'cg_detr':
model_inputs = dict(
src_vid=self._video_feats,
src_vid_mask=self._video_mask,
src_vid=video['video_feats'],
src_vid_mask=video['video_mask'],
src_txt=query_feats,
src_txt_mask=query_mask,
src_aud=self._audio_feats,
src_aud=video['audio_feats'],
vid=None, qid=None
)
else:
model_inputs = dict(
src_vid=self._video_feats,
src_vid_mask=self._video_mask,
src_vid=video['video_feats'],
src_vid_mask=video['video_mask'],
src_txt=query_feats,
src_txt_mask=query_mask,
src_aud=self._audio_feats
src_aud=video['audio_feats']
)

if self._model_name == 'taskweave':
@@ -196,11 +200,10 @@ def _post_processing(
prob = F.softmax(outputs["pred_logits"], -1).squeeze(0).cpu()
scores = prob[:,0]
pred_spans = outputs["pred_spans"].squeeze(0).cpu()

if self._video_feats is None:
video_feats = inputs["src_vid"]
if video_feats is None:
return [], []

video_duration = self._video_feats.shape[1] * self._clip_len
video_duration = video_feats.shape[1] * self._clip_len
pred_spans = torch.clamp(span_cxw_to_xx(pred_spans) * video_duration, min=0, max=video_duration)
cur_ranked_preds = torch.cat([pred_spans, scores[:, None]], dim=1).tolist()
cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
@@ -218,7 +221,7 @@ def _encode_audio(
_: torch.Tensor # mask, but not used.
audio_feats, _ = self._audio_encoder.encode(video_path)
return audio_feats

def _encode_text(
self,
query: str) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -233,7 +236,7 @@
@torch.no_grad()
def encode_video(
self,
video_path: str) -> None:
video_path: str) -> Dict[str, Optional[torch.Tensor]]:
video_feats: torch.Tensor
video_mask: torch.Tensor
if self._vision_encoder is not None:
Expand All @@ -245,15 +248,19 @@ def encode_video(
if n_frames > 75:
raise ValueError('The positional embedding only support video up to 150 secs (i.e., 75 2-sec clips) in length')
timestamed_video_feats = timestamed_video_feats.unsqueeze(0)
self._video_feats = timestamed_video_feats
self._video_mask = video_mask
self._video_path = video_path
self._audio_feats = self._encode_audio(video_path)

video = {
"video_feats" : timestamed_video_feats,
"video_mask" : video_mask,
"audio_feats" : self._encode_audio(video_path)
}

return video

@torch.no_grad()
def encode_audio(
self,
audio_path: str) -> None:
audio_path: str) -> Dict[str, torch.Tensor]:
if self._audio_encoder is None:
raise ValueError('The audio encoder is not initialized.')
audio_feats: torch.Tensor
Expand All @@ -265,19 +272,27 @@ def encode_audio(
tef_st = torch.arange(0, n_frames, 1.0) / n_frames
tef_ed = tef_st + 1.0 / n_frames
tef = torch.stack([tef_st, tef_ed], dim=1).to(self._device)
self._video_feats = tef.unsqueeze(0)
self._video_mask = audio_mask

audio = {
"video_feats" : tef.unsqueeze(0),
"video_mask" : audio_mask,
"audio_feats" : audio_feats
}

return audio

@torch.no_grad()
def predict(
self,
query: str) -> Optional[Dict[str, List[float]]]:
is_predictable = self._is_predictable()
query: str,
inputs: Dict[str, Optional[torch.Tensor]]) -> Optional[Dict[str, List[float]]]:
is_predictable = self._is_predictable(inputs)
if not is_predictable:
print("Error: No encoded features found in video variable. Did you forget to call encode_video() (MR-HD) or encode_audio() (AMR)?")
return None

query_feats, query_mask = self._encode_text(query)
inputs = self._prepare_batch(query_feats, query_mask)
inputs = self._prepare_batch(query_feats, query_mask, inputs)

if self._model_name == 'taskweave':
outputs, _ = self._model(**inputs)
@@ -288,11 +303,12 @@

if len(ranked_moments) == 0 and len(saliency_scores) == 0:
return None

prediction = {
"pred_relevant_windows": ranked_moments,
"pred_saliency_scores": saliency_scores,
}

return prediction


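With `encode_video()` / `encode_audio()` returning a plain dictionary (`video_feats`, `video_mask`, `audio_feats`) rather than mutating predictor state, the same encoded features can be reused across several queries. A small usage sketch under the new API (the checkpoint path is a placeholder; the queries are the examples from the demo scripts):

```python
import torch
from lighthouse.models import CGDETRPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CGDETRPredictor('/path/to/weight.ckpt', device=device,
                        feature_name='clip_slowfast', slowfast_path='SLOWFAST_8x8_R50.pkl')

# Encode once; the returned dict holds 'video_feats', 'video_mask', 'audio_feats'.
video = model.encode_video('api_example/RoripwjYFp8_60.0_210.0.mp4')

# Reuse the same encoded features for multiple text queries.
for query in ['A man is speaking in front of the camera',
              'A woman wearing a glass is speaking in front of the camera']:
    prediction = model.predict(query, video)
    if prediction is not None:
        print(query, prediction['pred_relevant_windows'][:3])
```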