Skip to content

Commit 816f0bd

Browse files
author
Oliver Scherer
committed
Implement derives for generic wrapper types
1 parent bafa54c commit 816f0bd

File tree

2 files changed

+154
-15
lines changed

2 files changed

+154
-15
lines changed

src/lib.rs

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,16 @@ fn newtype_inner(data: &syn::Data) -> Option<syn::Type> {
165165
pub fn from_primitive(input: TokenStream) -> TokenStream {
166166
let ast: syn::DeriveInput = syn::parse(input).unwrap();
167167
let name = &ast.ident;
168+
let (impl_, type_, where_) = &ast.generics.split_for_impl();
168169

169170
let impl_ = if let Some(inner_ty) = newtype_inner(&ast.data) {
171+
let bound = quote! { #inner_ty: _num_traits::FromPrimitive };
172+
let where_ = match where_ {
173+
Some(where_) => quote!{ #where_, #bound },
174+
None => quote!{ where #bound },
175+
};
170176
quote! {
171-
impl _num_traits::FromPrimitive for #name {
177+
impl #impl_ _num_traits::FromPrimitive for #name #type_ #where_ {
172178
fn from_i64(n: i64) -> Option<Self> {
173179
<#inner_ty as _num_traits::FromPrimitive>::from_i64(n).map(#name)
174180
}
@@ -251,7 +257,7 @@ pub fn from_primitive(input: TokenStream) -> TokenStream {
251257
};
252258

253259
quote! {
254-
impl _num_traits::FromPrimitive for #name {
260+
impl #impl_ _num_traits::FromPrimitive for #name #type_ #where_ {
255261
#[allow(trivial_numeric_casts)]
256262
fn from_i64(#from_i64_var: i64) -> Option<Self> {
257263
#(#clauses else)* {
@@ -321,10 +327,16 @@ pub fn from_primitive(input: TokenStream) -> TokenStream {
321327
pub fn to_primitive(input: TokenStream) -> TokenStream {
322328
let ast: syn::DeriveInput = syn::parse(input).unwrap();
323329
let name = &ast.ident;
330+
let (impl_, type_, where_) = &ast.generics.split_for_impl();
324331

325332
let impl_ = if let Some(inner_ty) = newtype_inner(&ast.data) {
333+
let bound = quote! { #inner_ty: _num_traits::ToPrimitive };
334+
let where_ = match where_ {
335+
Some(where_) => quote!{ #where_, #bound },
336+
None => quote!{ where #bound },
337+
};
326338
quote! {
327-
impl _num_traits::ToPrimitive for #name {
339+
impl #impl_ _num_traits::ToPrimitive for #name #type_ #where_ {
328340
fn to_i64(&self) -> Option<i64> {
329341
<#inner_ty as _num_traits::ToPrimitive>::to_i64(&self.0)
330342
}
@@ -410,7 +422,7 @@ pub fn to_primitive(input: TokenStream) -> TokenStream {
410422
};
411423

412424
quote! {
413-
impl _num_traits::ToPrimitive for #name {
425+
impl #impl_ _num_traits::ToPrimitive for #name #type_ #where_ {
414426
#[allow(trivial_numeric_casts)]
415427
fn to_i64(&self) -> Option<i64> {
416428
#match_expr
@@ -440,36 +452,41 @@ const NEWTYPE_ONLY: &str = "This trait can only be derived for newtypes";
440452
pub fn num_ops(input: TokenStream) -> TokenStream {
441453
let ast: syn::DeriveInput = syn::parse(input).unwrap();
442454
let name = &ast.ident;
455+
let (impl_, type_, where_) = &ast.generics.split_for_impl();
443456
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
457+
let where_ = match where_ {
458+
Some(where_) => quote!{ #where_, },
459+
None => quote!{ where },
460+
};
444461
dummy_const_trick(
445462
"NumOps",
446463
&name,
447464
quote! {
448-
impl ::std::ops::Add for #name {
465+
impl #impl_ ::std::ops::Add for #name #type_ #where_ #inner_ty: ::std::ops::Add<Output = #inner_ty> {
449466
type Output = Self;
450467
fn add(self, other: Self) -> Self {
451468
#name(<#inner_ty as ::std::ops::Add>::add(self.0, other.0))
452469
}
453470
}
454-
impl ::std::ops::Sub for #name {
471+
impl #impl_ ::std::ops::Sub for #name #type_ #where_ #inner_ty: ::std::ops::Sub<Output = #inner_ty> {
455472
type Output = Self;
456473
fn sub(self, other: Self) -> Self {
457474
#name(<#inner_ty as ::std::ops::Sub>::sub(self.0, other.0))
458475
}
459476
}
460-
impl ::std::ops::Mul for #name {
477+
impl #impl_ ::std::ops::Mul for #name #type_ #where_ #inner_ty: ::std::ops::Mul<Output = #inner_ty> {
461478
type Output = Self;
462479
fn mul(self, other: Self) -> Self {
463480
#name(<#inner_ty as ::std::ops::Mul>::mul(self.0, other.0))
464481
}
465482
}
466-
impl ::std::ops::Div for #name {
483+
impl #impl_ ::std::ops::Div for #name #type_ #where_ #inner_ty: ::std::ops::Div<Output = #inner_ty> {
467484
type Output = Self;
468485
fn div(self, other: Self) -> Self {
469486
#name(<#inner_ty as ::std::ops::Div>::div(self.0, other.0))
470487
}
471488
}
472-
impl ::std::ops::Rem for #name {
489+
impl #impl_ ::std::ops::Rem for #name #type_ #where_ #inner_ty: ::std::ops::Rem<Output = #inner_ty> {
473490
type Output = Self;
474491
fn rem(self, other: Self) -> Self {
475492
#name(<#inner_ty as ::std::ops::Rem>::rem(self.0, other.0))
@@ -488,13 +505,20 @@ pub fn num_ops(input: TokenStream) -> TokenStream {
488505
pub fn num_cast(input: TokenStream) -> TokenStream {
489506
let ast: syn::DeriveInput = syn::parse(input).unwrap();
490507
let name = &ast.ident;
508+
let (impl_, type_, where_) = &ast.generics.split_for_impl();
491509
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
510+
let where_ = match where_ {
511+
Some(where_) => quote!{ #where_, },
512+
None => quote!{ where },
513+
};
514+
let fn_param = proc_macro2::Ident::new("FROM_T", name.span());
492515
dummy_const_trick(
493516
"NumCast",
494517
&name,
495518
quote! {
496-
impl _num_traits::NumCast for #name {
497-
fn from<T: _num_traits::ToPrimitive>(n: T) -> Option<Self> {
519+
impl #impl_ _num_traits::NumCast for #name #type_ #where_ #inner_ty: _num_traits::NumCast {
520+
#[allow(non_camel_case_types)]
521+
fn from<#fn_param: _num_traits::ToPrimitive>(n: #fn_param) -> Option<Self> {
498522
<#inner_ty as _num_traits::NumCast>::from(n).map(#name)
499523
}
500524
}
@@ -510,12 +534,17 @@ pub fn num_cast(input: TokenStream) -> TokenStream {
510534
pub fn zero(input: TokenStream) -> TokenStream {
511535
let ast: syn::DeriveInput = syn::parse(input).unwrap();
512536
let name = &ast.ident;
537+
let (impl_, type_, where_) = &ast.generics.split_for_impl();
513538
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
539+
let where_ = match where_ {
540+
Some(where_) => quote!{ #where_, },
541+
None => quote!{ where },
542+
};
514543
dummy_const_trick(
515544
"Zero",
516545
&name,
517546
quote! {
518-
impl _num_traits::Zero for #name {
547+
impl #impl_ _num_traits::Zero for #name #type_ #where_ #inner_ty: _num_traits::Zero {
519548
fn zero() -> Self {
520549
#name(<#inner_ty as _num_traits::Zero>::zero())
521550
}
@@ -535,12 +564,17 @@ pub fn zero(input: TokenStream) -> TokenStream {
535564
pub fn one(input: TokenStream) -> TokenStream {
536565
let ast: syn::DeriveInput = syn::parse(input).unwrap();
537566
let name = &ast.ident;
567+
let (impl_, type_, where_) = &ast.generics.split_for_impl();
538568
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
569+
let where_ = match where_ {
570+
Some(where_) => quote!{ #where_, },
571+
None => quote!{ where },
572+
};
539573
dummy_const_trick(
540574
"One",
541575
&name,
542576
quote! {
543-
impl _num_traits::One for #name {
577+
impl #impl_ _num_traits::One for #name #type_ #where_ #inner_ty: _num_traits::One + PartialEq {
544578
fn one() -> Self {
545579
#name(<#inner_ty as _num_traits::One>::one())
546580
}
@@ -560,12 +594,17 @@ pub fn one(input: TokenStream) -> TokenStream {
560594
pub fn num(input: TokenStream) -> TokenStream {
561595
let ast: syn::DeriveInput = syn::parse(input).unwrap();
562596
let name = &ast.ident;
597+
let (impl_, type_, where_) = &ast.generics.split_for_impl();
563598
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
599+
let where_ = match where_ {
600+
Some(where_) => quote!{ #where_, },
601+
None => quote!{ where },
602+
};
564603
dummy_const_trick(
565604
"Num",
566605
&name,
567606
quote! {
568-
impl _num_traits::Num for #name {
607+
impl #impl_ _num_traits::Num for #name #type_ #where_ #inner_ty: _num_traits::Num {
569608
type FromStrRadixErr = <#inner_ty as _num_traits::Num>::FromStrRadixErr;
570609
fn from_str_radix(s: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
571610
<#inner_ty as _num_traits::Num>::from_str_radix(s, radix).map(#name)
@@ -584,12 +623,17 @@ pub fn num(input: TokenStream) -> TokenStream {
584623
pub fn float(input: TokenStream) -> TokenStream {
585624
let ast: syn::DeriveInput = syn::parse(input).unwrap();
586625
let name = &ast.ident;
626+
let (impl_, type_, where_) = &ast.generics.split_for_impl();
587627
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
628+
let where_ = match where_ {
629+
Some(where_) => quote!{ #where_, },
630+
None => quote!{ where },
631+
};
588632
dummy_const_trick(
589633
"Float",
590634
&name,
591635
quote! {
592-
impl _num_traits::Float for #name {
636+
impl #impl_ _num_traits::Float for #name #type_ #where_ #inner_ty: _num_traits::Float {
593637
fn nan() -> Self {
594638
#name(<#inner_ty as _num_traits::Float>::nan())
595639
}

tests/generic_newtype.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
extern crate num as num_renamed;
2+
#[macro_use]
3+
extern crate num_derive;
4+
5+
use crate::num_renamed::{Float, FromPrimitive, Num, NumCast, One, ToPrimitive, Zero};
6+
use std::ops::Neg;
7+
8+
#[derive(
9+
Debug,
10+
Clone,
11+
Copy,
12+
PartialEq,
13+
PartialOrd,
14+
ToPrimitive,
15+
FromPrimitive,
16+
NumOps,
17+
NumCast,
18+
One,
19+
Zero,
20+
Num,
21+
Float,
22+
)]
23+
struct MyThing<T: Cake>(T) where T: Lie;
24+
25+
trait Cake {}
26+
trait Lie {}
27+
28+
impl Cake for f32 {}
29+
impl Lie for f32 {}
30+
31+
impl<T: Neg<Output = T> + Cake + Lie> Neg for MyThing<T> {
32+
type Output = Self;
33+
fn neg(self) -> Self {
34+
MyThing(self.0.neg())
35+
}
36+
}
37+
38+
#[test]
39+
fn test_from_primitive() {
40+
assert_eq!(MyThing::from_u32(25), Some(MyThing(25.0)));
41+
}
42+
43+
#[test]
44+
fn test_from_primitive_128() {
45+
assert_eq!(
46+
MyThing::from_i128(std::i128::MIN),
47+
Some(MyThing((-2.0).powi(127)))
48+
);
49+
}
50+
51+
#[test]
52+
fn test_to_primitive() {
53+
assert_eq!(MyThing(25.0).to_u32(), Some(25));
54+
}
55+
56+
#[test]
57+
fn test_to_primitive_128() {
58+
let f: MyThing<f32> = MyThing::from_f32(std::f32::MAX).unwrap();
59+
assert_eq!(f.to_i128(), None);
60+
assert_eq!(f.to_u128(), Some(0xffff_ff00_0000_0000_0000_0000_0000_0000));
61+
}
62+
63+
#[test]
64+
fn test_num_ops() {
65+
assert_eq!(MyThing(25.0) + MyThing(10.0), MyThing(35.0));
66+
assert_eq!(MyThing(25.0) - MyThing(10.0), MyThing(15.0));
67+
assert_eq!(MyThing(25.0) * MyThing(2.0), MyThing(50.0));
68+
assert_eq!(MyThing(25.0) / MyThing(10.0), MyThing(2.5));
69+
assert_eq!(MyThing(25.0) % MyThing(10.0), MyThing(5.0));
70+
}
71+
72+
#[test]
73+
fn test_num_cast() {
74+
assert_eq!(<MyThing<f32> as NumCast>::from(25u8), Some(MyThing(25.0)));
75+
}
76+
77+
#[test]
78+
fn test_zero() {
79+
assert_eq!(MyThing::zero(), MyThing(0.0));
80+
}
81+
82+
#[test]
83+
fn test_one() {
84+
assert_eq!(MyThing::one(), MyThing(1.0));
85+
}
86+
87+
#[test]
88+
fn test_num() {
89+
assert_eq!(MyThing::from_str_radix("25", 10).ok(), Some(MyThing(25.0)));
90+
}
91+
92+
#[test]
93+
fn test_float() {
94+
assert_eq!(MyThing(4.0).log(MyThing(2.0)), MyThing(2.0));
95+
}

0 commit comments

Comments
 (0)