@@ -102,70 +102,38 @@ end
102102
103103@device_override Base. log (x:: Float64 ) = ccall (" extern __nv_log" , llvmcall, Cdouble, (Cdouble,), x)
104104@device_override Base. log (x:: Float32 ) = ccall (" extern __nv_logf" , llvmcall, Cfloat, (Cfloat,), x)
105- @device_override function Base. log (x:: Float16 )
106- log_x = @asmcall (""" {.reg.b32 f, C;
107- .reg.b16 r,h;
108- mov.b16 h,\$ 1;
109- cvt.f32.f16 f,h;
110- lg2.approx.ftz.f32 f,f;
111- mov.b32 C, 0x3f317218U;
112- mul.f32 f,f,C;
113- cvt.rn.f16.f32 r,f;
114- .reg.b16 spc, ulp, p;
115- mov.b16 spc, 0X160DU;
116- mov.b16 ulp, 0x9C00U;
117- set.eq.f16.f16 p, h, spc;
118- fma.rn.f16 r,p,ulp,r;
119- mov.b16 spc, 0X3BFEU;
120- mov.b16 ulp, 0x8010U;
121- set.eq.f16.f16 p, h, spc;
122- fma.rn.f16 r,p,ulp,r;
123- mov.b16 spc, 0X3C0BU;
124- mov.b16 ulp, 0x8080U;
125- set.eq.f16.f16 p, h, spc;
126- fma.rn.f16 r,p,ulp,r;
127- mov.b16 spc, 0X6051U;
128- mov.b16 ulp, 0x1C00U;
129- set.eq.f16.f16 p, h, spc;
130- fma.rn.f16 r,p,ulp,r;
131- mov.b16 \$ 0,r;
132- }""" , " =h,h" , Float16, Tuple{Float16}, x)
133- return log_x
105+ @device_override function Base. log (h:: Float16 )
106+ # perform computation in Float32 domain
107+ f = Float32 (h)
108+ f = @fastmath log (f)
109+ r = Float16 (f)
110+
111+ # handle degenrate cases
112+ r = fma (Float16 (h == reinterpret (Float16, 0x160D )), reinterpret (Float16, 0x9C00 ), r)
113+ r = fma (Float16 (h == reinterpret (Float16, 0x3BFE )), reinterpret (Float16, 0x8010 ), r)
114+ r = fma (Float16 (h == reinterpret (Float16, 0x3C0B )), reinterpret (Float16, 0x8080 ), r)
115+ r = fma (Float16 (h == reinterpret (Float16, 0x6051 )), reinterpret (Float16, 0x1C00 ), r)
116+
117+ return r
134118end
135119
136120@device_override FastMath. log_fast (x:: Float32 ) = ccall (" extern __nv_fast_logf" , llvmcall, Cfloat, (Cfloat,), x)
137121
138122@device_override Base. log10 (x:: Float64 ) = ccall (" extern __nv_log10" , llvmcall, Cdouble, (Cdouble,), x)
139123@device_override Base. log10 (x:: Float32 ) = ccall (" extern __nv_log10f" , llvmcall, Cfloat, (Cfloat,), x)
140- @device_override function Base. log10 (x:: Float16 )
141- log_x = @asmcall (""" {.reg.b16 h, r;
142- .reg.b32 f, C;
143- mov.b16 h, \$ 1;
144- cvt.f32.f16 f, h;
145- lg2.approx.ftz.f32 f, f;
146- mov.b32 C, 0x3E9A209BU;
147- mul.f32 f,f,C;
148- cvt.rn.f16.f32 r, f;
149- .reg.b16 spc, ulp, p;
150- mov.b16 spc, 0x338FU;
151- mov.b16 ulp, 0x1000U;
152- set.eq.f16.f16 p, h, spc;
153- fma.rn.f16 r,p,ulp,r;
154- mov.b16 spc, 0x33F8U;
155- mov.b16 ulp, 0x9000U;
156- set.eq.f16.f16 p, h, spc;
157- fma.rn.f16 r,p,ulp,r;
158- mov.b16 spc, 0x57E1U;
159- mov.b16 ulp, 0x9800U;
160- set.eq.f16.f16 p, h, spc;
161- fma.rn.f16 r,p,ulp,r;
162- mov.b16 spc, 0x719DU;
163- mov.b16 ulp, 0x9C00U;
164- set.eq.f16.f16 p, h, spc;
165- fma.rn.f16 r,p,ulp,r;
166- mov.b16 \$ 0, r;
167- }""" , " =h,h" , Float16, Tuple{Float16}, x)
168- return log_x
124+ @device_override function Base. log10 (h:: Float16 )
125+ # perform computation in Float32 domain
126+ f = Float32 (h)
127+ f = @fastmath log10 (f)
128+ r = Float16 (f)
129+
130+ # handle degenerate cases
131+ r = fma (Float16 (h == reinterpret (Float16, 0x338F )), reinterpret (Float16, 0x1000 ), r)
132+ r = fma (Float16 (h == reinterpret (Float16, 0x33F8 )), reinterpret (Float16, 0x9000 ), r)
133+ r = fma (Float16 (h == reinterpret (Float16, 0x57E1 )), reinterpret (Float16, 0x9800 ), r)
134+ r = fma (Float16 (h == reinterpret (Float16, 0x719D )), reinterpret (Float16, 0x9C00 ), r)
135+
136+ return r
169137end
170138@device_override FastMath. log10_fast (x:: Float32 ) = ccall (" extern __nv_fast_log10f" , llvmcall, Cfloat, (Cfloat,), x)
171139
@@ -174,25 +142,17 @@ end
174142
175143@device_override Base. log2 (x:: Float64 ) = ccall (" extern __nv_log2" , llvmcall, Cdouble, (Cdouble,), x)
176144@device_override Base. log2 (x:: Float32 ) = ccall (" extern __nv_log2f" , llvmcall, Cfloat, (Cfloat,), x)
177- @device_override function Base. log2 (x:: Float16 )
178- log_x = @asmcall (""" {.reg.b16 h, r;
179- .reg.b32 f;
180- mov.b16 h, \$ 1;
181- cvt.f32.f16 f, h;
182- lg2.approx.ftz.f32 f, f;
183- cvt.rn.f16.f32 r, f;
184- .reg.b16 spc, ulp, p;
185- mov.b16 spc, 0xA2E2U;
186- mov.b16 ulp, 0x8080U;
187- set.eq.f16.f16 p, r, spc;
188- fma.rn.f16 r,p,ulp,r;
189- mov.b16 spc, 0xBF46U;
190- mov.b16 ulp, 0x9400U;
191- set.eq.f16.f16 p, r, spc;
192- fma.rn.f16 r,p,ulp,r;
193- mov.b16 \$ 0, r;
194- }""" , " =h,h" , Float16, Tuple{Float16}, x)
195- return log_x
145+ @device_override function Base. log2 (h:: Float16 )
146+ # perform computation in Float32 domain
147+ f = Float32 (h)
148+ f = @fastmath log2 (f)
149+ r = Float16 (f)
150+
151+ # handle degenerate cases
152+ r = fma (Float16 (r == reinterpret (Float16, 0xA2E2 )), reinterpret (Float16, 0x8080 ), r)
153+ r = fma (Float16 (r == reinterpret (Float16, 0xBF46 )), reinterpret (Float16, 0x9400 ), r)
154+
155+ return r
196156end
197157@device_override FastMath. log2_fast (x:: Float32 ) = ccall (" extern __nv_fast_log2f" , llvmcall, Cfloat, (Cfloat,), x)
198158
@@ -207,94 +167,55 @@ end
207167
208168@device_override Base. exp (x:: Float64 ) = ccall (" extern __nv_exp" , llvmcall, Cdouble, (Cdouble,), x)
209169@device_override Base. exp (x:: Float32 ) = ccall (" extern __nv_expf" , llvmcall, Cfloat, (Cfloat,), x)
210- @device_override function Base. exp (x:: Float16 )
211- exp_x = @asmcall (""" {
212- .reg.b32 f, C, nZ;
213- .reg.b16 h,r;
214- mov.b16 h,\$ 1;
215- cvt.f32.f16 f,h;
216- mov.b32 C, 0x3fb8aa3bU;
217- mov.b32 nZ, 0x80000000U;
218- fma.rn.f32 f,f,C,nZ;
219- ex2.approx.ftz.f32 f,f;
220- cvt.rn.f16.f32 r,f;
221- .reg.b16 spc, ulp, p;
222- mov.b16 spc,0X1F79U;
223- mov.b16 ulp,0x9400U;
224- set.eq.f16.f16 p, h, spc;
225- fma.rn.f16 r,p,ulp,r;
226- mov.b16 spc,0X25CFU;
227- mov.b16 ulp,0x9400U;
228- set.eq.f16.f16 p, h, spc;
229- fma.rn.f16 r,p,ulp,r;
230- mov.b16 spc,0XC13BU;
231- mov.b16 ulp,0x0400U;
232- set.eq.f16.f16 p, h, spc;
233- fma.rn.f16 r,p,ulp,r;
234- mov.b16 spc,0XC1EFU;
235- mov.b16 ulp,0x0200U;
236- set.eq.f16.f16 p, h, spc;
237- fma.rn.f16 r,p,ulp,r;
238- mov.b16 \$ 0,r;
239- }""" , " =h,h" , Float16, Tuple{Float16}, x)
240- return exp_x
170+ @device_override function Base. exp (h:: Float16 )
171+ # perform computation in Float32 domain
172+ f = Float32 (h)
173+ f = fma (f, reinterpret (Float32, 0x3fb8aa3b ), reinterpret (Float32, Base. sign_mask (Float32)))
174+ f = @fastmath exp2 (f)
175+ r = Float16 (f)
176+
177+ # handle degenerate cases
178+ r = fma (Float16 (h == reinterpret (Float16, 0x1F79 )), reinterpret (Float16, 0x9400 ), r)
179+ r = fma (Float16 (h == reinterpret (Float16, 0x25CF )), reinterpret (Float16, 0x9400 ), r)
180+ r = fma (Float16 (h == reinterpret (Float16, 0xC13B )), reinterpret (Float16, 0x0400 ), r)
181+ r = fma (Float16 (h == reinterpret (Float16, 0xC1EF )), reinterpret (Float16, 0x0200 ), r)
182+
183+ return r
241184end
242185@device_override FastMath. exp_fast (x:: Float32 ) = ccall (" extern __nv_fast_expf" , llvmcall, Cfloat, (Cfloat,), x)
243186
244187@device_override Base. exp2 (x:: Float64 ) = ccall (" extern __nv_exp2" , llvmcall, Cdouble, (Cdouble,), x)
245188@device_override Base. exp2 (x:: Float32 ) = ccall (" extern __nv_exp2f" , llvmcall, Cfloat, (Cfloat,), x)
246- @device_override function Base. exp2 (x:: Float16 )
247- exp_x = @asmcall (""" {.reg.b32 f, ULP;
248- .reg.b16 r;
249- mov.b16 r,\$ 1;
250- cvt.f32.f16 f,r;
251- ex2.approx.ftz.f32 f,f;
252- mov.b32 ULP, 0x33800000U;
253- fma.rn.f32 f,f,ULP,f;
254- cvt.rn.f16.f32 r,f;
255- mov.b16 \$ 0,r;
256- }""" , " =h,h" , Float16, Tuple{Float16}, x)
257- return exp_x
189+ @device_override function Base. exp2 (h:: Float16 )
190+ # perform computation in Float32 domain
191+ f = Float32 (h)
192+ f = @fastmath exp2 (f)
193+
194+ # one ULP adjustement
195+ f = muladd (f, reinterpret (Float32, 0x33800000 ), f)
196+ r = Float16 (f)
197+
198+ return r
258199end
259200@device_override FastMath. exp2_fast (x:: Union{Float32, Float64} ) = exp2 (x)
260201
261202@device_override Base. exp10 (x:: Float64 ) = ccall (" extern __nv_exp10" , llvmcall, Cdouble, (Cdouble,), x)
262203@device_override Base. exp10 (x:: Float32 ) = ccall (" extern __nv_exp10f" , llvmcall, Cfloat, (Cfloat,), x)
263- @device_override function Base. exp10 (x:: Float16 )
264-
265- exp_x = @asmcall (""" {.reg.b16 h,r;
266- .reg.b32 f, C, nZ;
267- mov.b16 h, \$ 1;
268- cvt.f32.f16 f, h;
269- mov.b32 C, 0x40549A78U;
270- mov.b32 nZ, 0x80000000U;
271- fma.rn.f32 f,f,C,nZ;
272- ex2.approx.ftz.f32 f, f;
273- cvt.rn.f16.f32 r, f;
274- .reg.b16 spc, ulp, p;
275- mov.b16 spc,0x34DEU;
276- mov.b16 ulp,0x9800U;
277- set.eq.f16.f16 p, h, spc;
278- fma.rn.f16 r,p,ulp,r;
279- mov.b16 spc,0x9766U;
280- mov.b16 ulp,0x9000U;
281- set.eq.f16.f16 p, h, spc;
282- fma.rn.f16 r,p,ulp,r;
283- mov.b16 spc,0x9972U;
284- mov.b16 ulp,0x1000U;
285- set.eq.f16.f16 p, h, spc;
286- fma.rn.f16 r,p,ulp,r;
287- mov.b16 spc,0xA5C4U;
288- mov.b16 ulp,0x1000U;
289- set.eq.f16.f16 p, h, spc;
290- fma.rn.f16 r,p,ulp,r;
291- mov.b16 spc,0xBF0AU;
292- mov.b16 ulp,0x8100U;
293- set.eq.f16.f16 p, h, spc;
294- fma.rn.f16 r,p,ulp,r;
295- mov.b16 \$ 0, r;
296- }""" , " =h,h" , Float16, Tuple{Float16}, x)
297- return exp_x
204+ @device_override function Base. exp10 (h:: Float16 )
205+ # perform computation in Float32 domain
206+ f = Float32 (h)
207+ f = fma (f, reinterpret (Float32, 0x40549A78 ), reinterpret (Float32, Base. sign_mask (Float32)))
208+ f = @fastmath exp2 (f)
209+ r = Float16 (f)
210+
211+ # handle degenerate cases
212+ r = fma (Float16 (h == reinterpret (Float16, 0x34DE )), reinterpret (Float16, 0x9800 ), r)
213+ r = fma (Float16 (h == reinterpret (Float16, 0x9766 )), reinterpret (Float16, 0x9000 ), r)
214+ r = fma (Float16 (h == reinterpret (Float16, 0x9972 )), reinterpret (Float16, 0x1000 ), r)
215+ r = fma (Float16 (h == reinterpret (Float16, 0xA5C4 )), reinterpret (Float16, 0x1000 ), r)
216+ r = fma (Float16 (h == reinterpret (Float16, 0xBF0A )), reinterpret (Float16, 0x8100 ), r)
217+
218+ return r
298219end
299220@device_override FastMath. exp10_fast (x:: Float32 ) = ccall (" extern __nv_fast_exp10f" , llvmcall, Cfloat, (Cfloat,), x)
300221
0 commit comments