Skip to content
This repository was archived by the owner on Oct 10, 2019. It is now read-only.

Commit f1b0ad3

Browse files
committed
Domain support
1 parent 02654cd commit f1b0ad3

File tree

4 files changed

+140
-22
lines changed

4 files changed

+140
-22
lines changed

postgres-derive-internals/src/accepts.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,22 @@ use std::fmt::Write;
22

33
use enums::Variant;
44

5-
pub fn enum_body(variants: &[Variant]) -> String {
5+
pub fn enum_body(name: &str, variants: &[Variant]) -> String {
66
let mut body = String::new();
77

88
write!(body, "
9+
if type_.name() != \"{}\" {{
10+
return false;
11+
}}
12+
913
match *type_.kind() {{
1014
::postgres::types::Kind::Enum(ref variants) => {{
1115
if variants.len() != {} {{
1216
return false;
1317
}}
1418
1519
variants.iter().all(|v| {{
16-
match &**v {{", variants.len()).unwrap();
20+
match &**v {{", name, variants.len()).unwrap();
1721

1822
for variant in variants {
1923
write!(body, "

postgres-derive-internals/src/fromsql.rs

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::fmt::Write;
2-
use syn::{Body, Ident, MacroInput};
2+
use syn::{Body, Ident, MacroInput, VariantData, Field};
3+
use quote::{Tokens, ToTokens};
34

45
use accepts;
56
use enums::Variant;
@@ -13,7 +14,11 @@ pub fn expand_derive_fromsql(input: &MacroInput) -> Result<String, String> {
1314
let (accepts_body, to_sql_body) = match input.body {
1415
Body::Enum(ref variants) => {
1516
let variants: Vec<Variant> = try!(variants.iter().map(Variant::parse).collect());
16-
(accepts::enum_body(&variants), enum_body(&input.ident, &variants))
17+
(accepts::enum_body(&name, &variants), enum_body(&input.ident, &variants))
18+
}
19+
Body::Struct(VariantData::Tuple(ref fields)) if fields.len() == 1 => {
20+
let field = &fields[0];
21+
(domain_accepts_body(field), domain_body(&input.ident, field))
1722
}
1823
_ => {
1924
return Err("#[derive(ToSql)] may only be applied to structs, single field tuple \
@@ -23,22 +28,18 @@ pub fn expand_derive_fromsql(input: &MacroInput) -> Result<String, String> {
2328

2429
let out = format!("
2530
impl ::postgres::types::FromSql for {} {{
26-
fn from_sql(_: &::postgres::types::Type,
31+
fn from_sql(_type: &::postgres::types::Type,
2732
buf: &[u8],
28-
_: &::postgres::types::SessionInfo)
33+
_info: &::postgres::types::SessionInfo)
2934
-> ::std::result::Result<{},
3035
::std::boxed::Box<::std::error::Error +
3136
::std::marker::Sync +
3237
::std::marker::Send>> {{{}
3338
}}
3439
35-
fn accepts(type_: &::postgres::types::Type) -> bool {{
36-
if type_.name() != \"{}\" {{
37-
return false;
38-
}}
39-
{}
40+
fn accepts(type_: &::postgres::types::Type) -> bool {{{}
4041
}}
41-
}}", input.ident, input.ident, to_sql_body, name, accepts_body);
42+
}}", input.ident, input.ident, to_sql_body, accepts_body);
4243

4344
Ok(out)
4445
}
@@ -62,3 +63,17 @@ fn enum_body(ident: &Ident, variants: &[Variant]) -> String {
6263

6364
out
6465
}
66+
67+
fn domain_accepts_body(field: &Field) -> String {
68+
let mut tokens = Tokens::new();
69+
field.ty.to_tokens(&mut tokens);
70+
format!("
71+
<{} as ::postgres::types::FromSql>::accepts(type_)", tokens)
72+
}
73+
74+
fn domain_body(ident: &Ident, field: &Field) -> String {
75+
let mut tokens = Tokens::new();
76+
field.ty.to_tokens(&mut tokens);
77+
format!("\
78+
<{} as ::postgres::types::FromSql>::from_sql(_type, buf, _info).map({})", tokens, ident)
79+
}

postgres-derive-internals/src/tosql.rs

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::fmt::Write;
2-
use syn::{Body, Ident, MacroInput};
2+
use syn::{Body, Ident, MacroInput, VariantData, Field};
3+
use quote::{Tokens, ToTokens};
34

45
use accepts;
56
use enums::Variant;
@@ -13,7 +14,11 @@ pub fn expand_derive_tosql(input: &MacroInput) -> Result<String, String> {
1314
let (accepts_body, to_sql_body) = match input.body {
1415
Body::Enum(ref variants) => {
1516
let variants: Vec<Variant> = try!(variants.iter().map(Variant::parse).collect());
16-
(accepts::enum_body(&variants), enum_body(&input.ident, &variants))
17+
(accepts::enum_body(&name, &variants), enum_body(&input.ident, &variants))
18+
}
19+
Body::Struct(VariantData::Tuple(ref fields)) if fields.len() == 1 => {
20+
let field = &fields[0];
21+
(domain_accepts_body(&name, &field), domain_body())
1722
}
1823
_ => {
1924
return Err("#[derive(ToSql)] may only be applied to structs, single field tuple \
@@ -24,24 +29,20 @@ pub fn expand_derive_tosql(input: &MacroInput) -> Result<String, String> {
2429
let out = format!("
2530
impl ::postgres::types::ToSql for {} {{
2631
fn to_sql(&self,
27-
_: &::postgres::types::Type,
32+
_type: &::postgres::types::Type,
2833
buf: &mut ::std::vec::Vec<u8>,
29-
_: &::postgres::types::SessionInfo)
34+
_info: &::postgres::types::SessionInfo)
3035
-> ::std::result::Result<::postgres::types::IsNull,
3136
::std::boxed::Box<::std::error::Error +
3237
::std::marker::Sync +
3338
::std::marker::Send>> {{{}
3439
}}
3540
36-
fn accepts(type_: &::postgres::types::Type) -> bool {{
37-
if type_.name() != \"{}\" {{
38-
return false;
39-
}}
40-
{}
41+
fn accepts(type_: &::postgres::types::Type) -> bool {{{}
4142
}}
4243
4344
to_sql_checked!();
44-
}}", input.ident, to_sql_body, name, accepts_body);
45+
}}", input.ident, to_sql_body, accepts_body);
4546

4647
Ok(out)
4748
}
@@ -63,3 +64,30 @@ fn enum_body(ident: &Ident, variants: &[Variant]) -> String {
6364

6465
out
6566
}
67+
68+
fn domain_accepts_body(name: &str, field: &Field) -> String {
69+
let mut tokens = Tokens::new();
70+
field.ty.to_tokens(&mut tokens);
71+
72+
format!("
73+
if type_.name() != \"{}\" {{
74+
return false;
75+
}}
76+
77+
match *type_.kind() {{
78+
::postgres::types::Kind::Domain(ref type_) => {{
79+
<{} as ::postgres::types::ToSql>::accepts(type_)
80+
}}
81+
_ => false,
82+
}}", name, tokens)
83+
}
84+
85+
fn domain_body() -> String {
86+
"
87+
let type_ = match *_type.kind() {
88+
::postgres::types::Kind::Domain(ref type_) => type_,
89+
_ => unreachable!(),
90+
};
91+
92+
::postgres::types::ToSql::to_sql(&self.0, type_, buf, _info)".to_owned()
93+
}

postgres-derive/tests/domains.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#![feature(rustc_macro)]
2+
3+
#[macro_use]
4+
extern crate postgres_derive;
5+
#[macro_use]
6+
extern crate postgres;
7+
8+
use postgres::{Connection, TlsMode};
9+
use postgres::error::Error;
10+
use postgres::types::WrongType;
11+
12+
mod util;
13+
14+
#[test]
15+
fn defaults() {
16+
#[derive(FromSql, ToSql, Debug, PartialEq)]
17+
struct SessionId(Vec<u8>);
18+
19+
let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
20+
conn.execute("CREATE DOMAIN pg_temp.\"SessionId\" AS bytea CHECK(octet_length(VALUE) = 16);",
21+
&[])
22+
.unwrap();
23+
24+
util::test_type(&conn, "\"SessionId\"", &[(SessionId(b"0123456789abcdef".to_vec()),
25+
"'0123456789abcdef'")]);
26+
}
27+
28+
#[test]
29+
fn name_overrides() {
30+
#[derive(FromSql, ToSql, Debug, PartialEq)]
31+
#[postgres(name = "session_id")]
32+
struct SessionId(Vec<u8>);
33+
34+
let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
35+
conn.execute("CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);", &[])
36+
.unwrap();
37+
38+
util::test_type(&conn, "session_id", &[(SessionId(b"0123456789abcdef".to_vec()),
39+
"'0123456789abcdef'")]);
40+
}
41+
42+
#[test]
43+
fn wrong_name() {
44+
#[derive(FromSql, ToSql, Debug, PartialEq)]
45+
struct SessionId(Vec<u8>);
46+
47+
let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
48+
conn.execute("CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);", &[])
49+
.unwrap();
50+
51+
match conn.execute("SELECT $1::session_id", &[&SessionId(vec![])]) {
52+
Err(Error::Conversion(ref r)) if r.is::<WrongType>() => {}
53+
v => panic!("unexpected response {:?}", v),
54+
}
55+
}
56+
57+
#[test]
58+
fn wrong_type() {
59+
#[derive(FromSql, ToSql, Debug, PartialEq)]
60+
#[postgres(name = "session_id")]
61+
struct SessionId(i32);
62+
63+
let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
64+
conn.execute("CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);", &[])
65+
.unwrap();
66+
67+
match conn.execute("SELECT $1::session_id", &[&SessionId(0)]) {
68+
Err(Error::Conversion(ref r)) if r.is::<WrongType>() => {}
69+
v => panic!("unexpected response {:?}", v),
70+
}
71+
}

0 commit comments

Comments
 (0)