Skip to content

Commit 01c120f

Browse files
committed
refactor int decoding
1 parent 8733269 commit 01c120f

File tree

5 files changed

+96
-32
lines changed

5 files changed

+96
-32
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## 0.6.40
9+
10+
- mssql: Allow decoding various numeric types as i16
11+
812
## 0.6.39
913

1014
- Fix `COPY` error handling in Postgres

Cargo.lock

+3-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ serde = { version = "1.0.132", features = ["derive"] }
144144
serde_json = "1.0.73"
145145
url = "2.2.2"
146146
rand = "0.9"
147-
rand_xoshiro = "0.6.0"
147+
rand_xoshiro = "0.7.0"
148148
hex = "0.4.3"
149149
tempdir = "0.3.7"
150150
# Needed to test SQLCipher

sqlx-core/src/mssql/types/int.rs

+76-28
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use byteorder::{ByteOrder, LittleEndian};
2+
use std::any::type_name;
23

34
use crate::decode::Decode;
45
use crate::encode::{Encode, IsNull};
@@ -27,7 +28,7 @@ impl Encode<'_, Mssql> for i8 {
2728

2829
impl Decode<'_, Mssql> for i8 {
2930
fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
30-
Ok(i8::from_le_bytes(value.as_bytes()?[0..1].try_into()?))
31+
decode_integer(value)
3132
}
3233
}
3334

@@ -37,7 +38,10 @@ impl Type<Mssql> for i16 {
3738
}
3839

3940
fn compatible(ty: &MssqlTypeInfo) -> bool {
40-
matches!(ty.0.ty, DataType::SmallInt | DataType::IntN) && ty.0.size == 2
41+
matches!(
42+
ty.0.ty,
43+
DataType::TinyInt | DataType::SmallInt | DataType::Int | DataType::IntN
44+
) && ty.0.size <= 2
4145
}
4246
}
4347

@@ -51,7 +55,7 @@ impl Encode<'_, Mssql> for i16 {
5155

5256
impl Decode<'_, Mssql> for i16 {
5357
fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
54-
Ok(LittleEndian::read_i16(value.as_bytes()?))
58+
decode_integer(value)
5559
}
5660
}
5761

@@ -75,7 +79,7 @@ impl Encode<'_, Mssql> for i32 {
7579

7680
impl Decode<'_, Mssql> for i32 {
7781
fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
78-
Ok(LittleEndian::read_i32(value.as_bytes()?))
82+
decode_integer(value)
7983
}
8084
}
8185

@@ -110,30 +114,7 @@ impl Encode<'_, Mssql> for i64 {
110114

111115
impl Decode<'_, Mssql> for i64 {
112116
fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
113-
let ty = value.type_info.0.ty;
114-
let precision = value.type_info.0.precision;
115-
let scale = value.type_info.0.scale;
116-
match ty {
117-
DataType::SmallInt
118-
| DataType::Int
119-
| DataType::TinyInt
120-
| DataType::BigInt
121-
| DataType::IntN => {
122-
let mut buf = [0u8; 8];
123-
let bytes_val = value.as_bytes()?;
124-
buf[..bytes_val.len()].copy_from_slice(bytes_val);
125-
Ok(i64::from_le_bytes(buf))
126-
}
127-
DataType::Numeric | DataType::NumericN | DataType::Decimal | DataType::DecimalN => {
128-
decode_numeric(value.as_bytes()?, precision, scale)
129-
}
130-
_ => Err(err_protocol!(
131-
"Decoding {:?} as a float failed because type {:?} is not implemented",
132-
value,
133-
ty
134-
)
135-
.into()),
136-
}
117+
decode_integer(value)
137118
}
138119
}
139120

@@ -150,3 +131,70 @@ fn decode_numeric(bytes: &[u8], _precision: u8, mut scale: u8) -> Result<i64, Bo
150131
let n = i64::try_from(numerator)?;
151132
Ok(n * if negative { -1 } else { 1 })
152133
}
134+
135+
fn decode_integer<T>(value: MssqlValueRef<'_>) -> Result<T, BoxDynError>
136+
where
137+
T: TryFrom<i64>,
138+
T::Error: std::error::Error + Send + Sync + 'static,
139+
{
140+
let ty = value.type_info.0.ty;
141+
let precision = value.type_info.0.precision;
142+
let scale = value.type_info.0.scale;
143+
144+
let type_name = type_name::<T>();
145+
146+
match ty {
147+
DataType::SmallInt
148+
| DataType::Int
149+
| DataType::TinyInt
150+
| DataType::BigInt
151+
| DataType::IntN => {
152+
let mut buf = [0u8; 8];
153+
let bytes_val = value.as_bytes()?;
154+
let len = bytes_val.len();
155+
156+
if len > buf.len() {
157+
return Err(err_protocol!(
158+
"Decoding {:?} as a {} failed because type {:?} has more than {} bytes",
159+
value,
160+
type_name,
161+
ty,
162+
buf.len()
163+
)
164+
.into());
165+
}
166+
167+
buf[..len].copy_from_slice(&bytes_val);
168+
169+
let i64_val = i64::from_le_bytes(buf);
170+
T::try_from(i64_val).map_err(|_| {
171+
err_protocol!(
172+
"Decoding {:?} as a {} failed because value {} is out of range",
173+
value,
174+
type_name,
175+
i64_val
176+
)
177+
.into()
178+
})
179+
}
180+
DataType::Numeric | DataType::NumericN | DataType::Decimal | DataType::DecimalN => {
181+
let n = decode_numeric(value.as_bytes()?, precision, scale)?;
182+
T::try_from(n).map_err(|_| {
183+
err_protocol!(
184+
"Decoding {:?} as a {} failed because value {} is out of range",
185+
value,
186+
type_name,
187+
n
188+
)
189+
.into()
190+
})
191+
}
192+
_ => Err(err_protocol!(
193+
"Decoding {:?} as a {} failed because type {:?} is not implemented",
194+
value,
195+
type_name,
196+
ty
197+
)
198+
.into()),
199+
}
200+
}

tests/mssql/mssql.rs

+12
Original file line numberDiff line numberDiff line change
@@ -489,3 +489,15 @@ async fn test_pool_callbacks() -> anyhow::Result<()> {
489489

490490
Ok(())
491491
}
492+
493+
#[sqlx_macros::test]
494+
async fn it_can_decode_tinyint_as_i16() -> anyhow::Result<()> {
495+
let mut conn = new::<Mssql>().await?;
496+
497+
let row: MssqlRow = conn.fetch_one("SELECT CAST(42 AS TINYINT) as val").await?;
498+
let v: i16 = row.try_get("val")?;
499+
500+
assert_eq!(v, 42);
501+
502+
Ok(())
503+
}

0 commit comments

Comments
 (0)