Skip to content

Commit f1f1966

Browse files
committed
Implement a first version of modular channel decoding.
Does not support the Weighted predictor or prev-channel properties yet, and is not optimized.
1 parent cf1d5fb commit f1f1966

File tree

7 files changed

+480 -36 lines changed

jxl/src/frame/modular.rs

+64-25
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::fmt::Debug;
77

88
use crate::{
99
bit_reader::BitReader,
10-
error::{Error, Result},
10+
error::Result,
1111
headers::{
1212
extra_channels::ExtraChannelInfo, frame_header::FrameHeader, modular::GroupHeader,
1313
JxlHeader,
@@ -16,10 +16,12 @@ use crate::{
1616
util::{tracing_wrappers::*, CeilLog2},
1717
};
1818

19+
mod decode;
1920
mod predict;
2021
mod transforms;
2122
mod tree;
2223

24+
use decode::{decode_modular_section, ModularStreamId};
2325
pub use predict::Predictor;
2426
use transforms::{make_grids, TransformStepChunk};
2527
pub use tree::Tree;
@@ -63,9 +65,9 @@ impl ChannelInfo {
6365
enum ModularGridKind {
6466
// Single big channel.
6567
None,
66-
// 2048x2048 image-pixels.
68+
// 2048x2048 image-pixels (if modular_group_shift == 1).
6769
Lf,
68-
// 256x256 image-pixels.
70+
// 256x256 image-pixels (if modular_group_shift == 1).
6971
Hf,
7072
}
7173

@@ -78,6 +80,7 @@ struct ModularBuffer {
7880
auxiliary_data: Option<Image<i32>>,
7981
remaining_uses: usize,
8082
used_by_transforms: Vec<usize>,
83+
size: (usize, usize),
8184
}
8285

8386
#[allow(dead_code)]
@@ -144,50 +147,78 @@ impl FullModularImage {
144147
});
145148
}
146149

150+
if channels.is_empty() {
151+
return Ok(Self {
152+
buffer_info: vec![],
153+
transform_steps: vec![],
154+
section_buffer_indices: vec![vec![]; 2 + frame_header.passes.num_passes as usize],
155+
});
156+
}
157+
147158
trace!("reading modular header");
148159
let header = GroupHeader::read(br)?;
149160

150-
if header.use_global_tree && global_tree.is_none() {
151-
return Err(Error::NoGlobalTree);
152-
}
153-
154161
let (mut buffer_info, transform_steps) =
155162
transforms::meta_apply_transforms(&channels, &header.transforms)?;
156163

157164
// Assign each (channel, group) pair present in the bitstream to the section in which it will be decoded.
158165
let mut section_buffer_indices: Vec<Vec<usize>> = vec![];
159166

167+
let mut sorted_buffers: Vec<_> = buffer_info
168+
.iter()
169+
.enumerate()
170+
.filter_map(|(i, b)| {
171+
if b.is_coded {
172+
Some((b.channel_id, i))
173+
} else {
174+
None
175+
}
176+
})
177+
.collect();
178+
179+
sorted_buffers.sort_by_key(|x| x.0);
180+
160181
section_buffer_indices.push(
161-
buffer_info
182+
sorted_buffers
162183
.iter()
163-
.enumerate()
164-
.filter(|x| x.1.is_coded)
165-
.take_while(|x| x.1.info.is_meta_or_small(frame_header.group_dim()))
166-
.map(|x| x.0)
184+
.take_while(|x| {
185+
buffer_info[x.1]
186+
.info
187+
.is_meta_or_small(frame_header.group_dim())
188+
})
189+
.map(|x| x.1)
167190
.collect(),
168191
);
169192

170193
section_buffer_indices.push(
171-
buffer_info
194+
sorted_buffers
172195
.iter()
173-
.enumerate()
174-
.filter(|x| x.1.is_coded)
175-
.skip_while(|x| x.1.info.is_meta_or_small(frame_header.group_dim()))
176-
.filter(|x| x.1.info.is_shift_in_range(3, usize::MAX))
177-
.map(|x| x.0)
196+
.skip_while(|x| {
197+
buffer_info[x.1]
198+
.info
199+
.is_meta_or_small(frame_header.group_dim())
200+
})
201+
.filter(|x| buffer_info[x.1].info.is_shift_in_range(3, usize::MAX))
202+
.map(|x| x.1)
178203
.collect(),
179204
);
180205

181206
for pass in 0..frame_header.passes.num_passes as usize {
182207
let (min_shift, max_shift) = frame_header.passes.downsampling_bracket(pass);
183208
section_buffer_indices.push(
184-
buffer_info
209+
sorted_buffers
185210
.iter()
186-
.enumerate()
187-
.filter(|x| x.1.is_coded)
188-
.filter(|x| !x.1.info.is_meta_or_small(frame_header.group_dim()))
189-
.filter(|x| x.1.info.is_shift_in_range(min_shift, max_shift))
190-
.map(|x| x.0)
211+
.filter(|x| {
212+
!buffer_info[x.1]
213+
.info
214+
.is_meta_or_small(frame_header.group_dim())
215+
})
216+
.filter(|x| {
217+
buffer_info[x.1]
218+
.info
219+
.is_shift_in_range(min_shift, max_shift)
220+
})
221+
.map(|x| x.1)
191222
.collect(),
192223
);
193224
}
@@ -206,7 +237,15 @@ impl FullModularImage {
206237
&mut buffer_info,
207238
);
208239

209-
// TODO(veluca93): read global channels
240+
decode_modular_section(
241+
&mut buffer_info,
242+
&section_buffer_indices[0],
243+
0,
244+
ModularStreamId::GlobalData.get_id(frame_header),
245+
&header,
246+
global_tree,
247+
br,
248+
)?;
210249

211250
Ok(FullModularImage {
212251
buffer_info,

jxl/src/frame/modular/decode.rs

+156
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2+
//
3+
// Use of this source code is governed by a BSD-style
4+
// license that can be found in the LICENSE file.
5+
6+
use crate::{
7+
bit_reader::BitReader,
8+
entropy_coding::decode::Reader,
9+
error::{Error, Result},
10+
frame::quantizer::NUM_QUANT_TABLES,
11+
headers::{frame_header::FrameHeader, modular::GroupHeader},
12+
image::Image,
13+
util::tracing_wrappers::*,
14+
};
15+
16+
use super::{
17+
predict::WeightedPredictorState, tree::NUM_NONREF_PROPERTIES, ModularBufferInfo, Tree,
18+
};
19+
20+
/// Identifies one of the modular sub-streams of a frame.
///
/// `ModularStreamId::get_id` turns a value of this type into the numeric stream id.
/// The payloads are plain offsets added to a per-kind base id.
// Not every variant is constructed by the decoder yet, hence the blanket `unused` allowance.
#[allow(unused)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModularStreamId {
    /// Global modular data; maps to stream id 0.
    GlobalData,
    /// VarDCT LF stream; the payload is presumably the LF group index (confirm with callers).
    VarDCTLF(usize),
    /// Modular LF stream; the payload is presumably the LF group index (confirm with callers).
    ModularLF(usize),
    /// LF metadata stream; the payload is presumably the LF group index (confirm with callers).
    LFMeta(usize),
    /// Quantization table stream; the payload is presumably the table index (confirm with callers).
    QuantTable(usize),
    /// Modular HF stream, addressed by pass and group.
    ModularHF { pass: usize, group: usize },
}
29+
30+
impl ModularStreamId {
31+
pub fn get_id(&self, frame_header: &FrameHeader) -> usize {
32+
match self {
33+
Self::GlobalData => 0,
34+
Self::VarDCTLF(g) => 1 + g,
35+
Self::ModularLF(g) => 1 + frame_header.num_lf_groups() + g,
36+
Self::LFMeta(g) => 1 + frame_header.num_lf_groups() * 2 + g,
37+
Self::QuantTable(q) => 1 + frame_header.num_lf_groups() * 3 + q,
38+
Self::ModularHF { pass, group } => {
39+
1 + frame_header.num_lf_groups() * 3
40+
+ NUM_QUANT_TABLES
41+
+ frame_header.num_groups() * *pass
42+
+ *group
43+
}
44+
}
45+
}
46+
}
47+
48+
#[allow(clippy::too_many_arguments)]
49+
#[instrument(level = "debug", skip(buffers, reader, br))]
50+
fn decode_modular_channel(
51+
buffers: &mut [ModularBufferInfo],
52+
buffer_indices: &[usize],
53+
index: usize,
54+
grid_index: usize,
55+
stream_id: usize,
56+
header: &GroupHeader,
57+
tree: &Tree,
58+
reader: &mut Reader,
59+
br: &mut BitReader,
60+
) -> Result<()> {
61+
debug!("reading channel");
62+
let size = {
63+
let b = &mut buffers[buffer_indices[index]].buffer_grid[grid_index];
64+
if b.data.is_none() {
65+
b.data = Some(Image::new(b.size)?)
66+
}
67+
b.size
68+
};
69+
70+
let chan = buffers[buffer_indices[index]].channel_id;
71+
let mut wp_state = WeightedPredictorState::new(header);
72+
for y in 0..size.1 {
73+
let mut property_buffer = [0; 256];
74+
property_buffer[0] = chan as i32;
75+
property_buffer[1] = stream_id as i32;
76+
for x in 0..size.0 {
77+
let prediction_result = tree.predict(
78+
buffers,
79+
buffer_indices,
80+
index,
81+
grid_index,
82+
&mut wp_state,
83+
x,
84+
y,
85+
&mut property_buffer,
86+
);
87+
let dec = reader.read_signed(br, prediction_result.context as usize)?;
88+
let val =
89+
prediction_result.guess + (prediction_result.multiplier as i64) * (dec as i64);
90+
buffers[buffer_indices[index]].buffer_grid[grid_index]
91+
.data
92+
.as_mut()
93+
.unwrap()
94+
.as_rect_mut()
95+
.row(y)[x] = val as i32;
96+
trace!(y, x, val, dec, ?property_buffer, ?prediction_result);
97+
// TODO(veluca): update WP errors.
98+
}
99+
}
100+
101+
Ok(())
102+
}
103+
104+
pub fn decode_modular_section(
105+
buffers: &mut [ModularBufferInfo],
106+
buffer_indices: &[usize],
107+
grid_index: usize,
108+
stream_id: usize,
109+
header: &GroupHeader,
110+
global_tree: &Option<Tree>,
111+
br: &mut BitReader,
112+
) -> Result<()> {
113+
if buffer_indices.is_empty() {
114+
return Ok(());
115+
}
116+
if header.use_global_tree && global_tree.is_none() {
117+
return Err(Error::NoGlobalTree);
118+
}
119+
let local_tree = if !header.use_global_tree {
120+
Some(Tree::read(br, 1024)?)
121+
} else {
122+
None
123+
};
124+
let tree = if header.use_global_tree {
125+
global_tree.as_ref().unwrap()
126+
} else {
127+
local_tree.as_ref().unwrap()
128+
};
129+
130+
if tree.max_property() >= NUM_NONREF_PROPERTIES - 2 {
131+
todo!(
132+
"WP and reference properties are not implemented yet, max property: {}",
133+
tree.max_property()
134+
);
135+
}
136+
137+
let mut reader = tree.histograms.make_reader(br)?;
138+
139+
for i in 0..buffer_indices.len() {
140+
decode_modular_channel(
141+
buffers,
142+
buffer_indices,
143+
i,
144+
grid_index,
145+
stream_id,
146+
header,
147+
tree,
148+
&mut reader,
149+
br,
150+
)?;
151+
}
152+
153+
reader.check_final_state()?;
154+
155+
Ok(())
156+
}

jxl/src/frame/modular/predict.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
// Use of this source code is governed by a BSD-style
44
// license that can be found in the LICENSE file.
55

6-
use crate::error::{Error, Result};
6+
use crate::{
7+
error::{Error, Result},
8+
headers::modular::GroupHeader,
9+
};
710
use num_derive::FromPrimitive;
811
use num_traits::FromPrimitive;
912

@@ -37,3 +40,18 @@ impl TryFrom<u32> for Predictor {
3740
impl Predictor {
3841
pub const NUM_PREDICTORS: u32 = Predictor::AverageAll as u32 + 1;
3942
}
43+
44+
#[derive(Debug)]
45+
pub struct WeightedPredictorState;
46+
47+
impl WeightedPredictorState {
48+
pub fn new(_header: &GroupHeader) -> Self {
49+
// TODO(veluca): implement the weighted predictor.
50+
Self
51+
}
52+
53+
pub fn predict_and_property(&self) -> (i64, i32) {
54+
// TODO(veluca): implement the weighted predictor.
55+
(0, 0)
56+
}
57+
}

jxl/src/frame/modular/transforms.rs

+26-5
Original file line numberDiff line numberDiff line change
@@ -463,11 +463,32 @@ pub fn make_grids(
463463
// Create grids.
464464
for g in buffer_info.iter_mut() {
465465
g.buffer_grid = get_grid_indices(g.grid_kind)
466-
.map(|_| ModularBuffer {
467-
data: None,
468-
auxiliary_data: None,
469-
remaining_uses: 0,
470-
used_by_transforms: vec![],
466+
.map(|(x, y)| {
467+
let chan_size = g.info.size;
468+
let size = match g.grid_kind {
469+
ModularGridKind::None => chan_size,
470+
ModularGridKind::Lf => {
471+
let dx = frame_header.lf_group_dim() >> g.info.shift.unwrap().0;
472+
let bx = x as usize * dx;
473+
let dy = frame_header.lf_group_dim() >> g.info.shift.unwrap().1;
474+
let by = y as usize * dy;
475+
((chan_size.0 - bx).min(dx), (chan_size.1 - by).min(dy))
476+
}
477+
ModularGridKind::Hf => {
478+
let dx = frame_header.group_dim() >> g.info.shift.unwrap().0;
479+
let bx = x as usize * dx;
480+
let dy = frame_header.group_dim() >> g.info.shift.unwrap().1;
481+
let by = y as usize * dy;
482+
((chan_size.0 - bx).min(dx), (chan_size.1 - by).min(dy))
483+
}
484+
};
485+
ModularBuffer {
486+
data: None,
487+
auxiliary_data: None,
488+
remaining_uses: 0,
489+
used_by_transforms: vec![],
490+
size,
491+
}
471492
})
472493
.collect();
473494
}

0 commit comments

Comments
 (0)