Skip to content

Commit 38ff44f

Browse files
ravenousengxson
andcommitted
Refactor JANUS_PRO handling in clip.cpp
Co-authored-by: Xuan-Son Nguyen <[email protected]>
1 parent 9601dc8 commit 38ff44f

File tree

1 file changed

+10
-30
lines changed

1 file changed

+10
-30
lines changed

tools/mtmd/clip.cpp

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,15 @@ struct clip_graph {
550550
cur = ggml_gelu(ctx0, cur);
551551
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
552552
cur = ggml_add(ctx0, cur, model.mm_2_b);
553+
554+
} else if (ctx->proj_type() == PROJECTOR_TYPE_JANUS_PRO) {
555+
cur = build_ffn(cur,
556+
model.mm_0_w, model.mm_0_b,
557+
nullptr, nullptr,
558+
model.mm_1_w, model.mm_1_b,
559+
hparams.ffn_op,
560+
-1);
561+
553562
} else {
554563
GGML_ABORT("SigLIP: Unsupported projector type");
555564
}
@@ -1508,35 +1517,6 @@ struct clip_graph {
15081517

15091518
return gf;
15101519
}
1511-
1512-
ggml_cgraph * build_janus_pro() {
1513-
GGML_ASSERT(model.class_embedding == nullptr); // No CLS token
1514-
1515-
ggml_tensor * inp = build_inp();
1516-
1517-
ggml_tensor * learned_pos_embd = model.position_embeddings;
1518-
1519-
ggml_tensor * cur = build_vit(
1520-
inp, n_patches,
1521-
NORM_TYPE_NORMAL,
1522-
hparams.ffn_op,
1523-
learned_pos_embd,
1524-
nullptr);
1525-
1526-
cur = build_ffn(cur,
1527-
model.mm_0_w, model.mm_0_b,
1528-
nullptr, nullptr,
1529-
model.mm_1_w, model.mm_1_b,
1530-
hparams.ffn_op,
1531-
-1);
1532-
cb(cur, "aligner_1", -1);
1533-
1534-
// build the graph
1535-
ggml_build_forward_expand(gf, cur);
1536-
1537-
return gf;
1538-
}
1539-
15401520
// whisper encoder with custom projector
15411521
ggml_cgraph * build_whisper_enc() {
15421522
const int n_frames = img.nx;
@@ -2156,7 +2136,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
21562136
} break;
21572137
case PROJECTOR_TYPE_JANUS_PRO:
21582138
{
2159-
res = graph.build_janus_pro();
2139+
res = graph.build_siglip();
21602140
} break;
21612141
default:
21622142
{

0 commit comments

Comments
 (0)