@@ -309,20 +309,34 @@ def get_local_mesh(self):
309309 X = [np .broadcast_to (x , self .real_shape ()) for x in X ]
310310 return X
311311
312- def get_local_wavenumbermesh (self , scaled = False ):
312+ def get_local_wavenumbermesh (self , scaled = False , broadcast = False ,
313+ eliminate_highest_freq = False ):
313314 """Returns (scaled) local decomposed wavenumbermesh
314315
315316 If scaled is True, then the wavenumbermesh is scaled with physical mesh
316317 size. This takes care of mapping the physical domain to a computational
317318 cube of size (2pi)**3
319+
320+
318321 """
319- kx , ky , kz = self .complex_local_wavenumbers ()
322+ s = self .complex_local_slice ()
323+ kx = fftfreq (self .N [0 ], 1. / self .N [0 ]).astype (int )
324+ ky = fftfreq (self .N [1 ], 1. / self .N [1 ]).astype (int )
325+ kz = rfftfreq (self .N [2 ], 1. / self .N [2 ]).astype (int )
326+ if eliminate_highest_freq :
327+ for i , k in enumerate ((kx , ky , kz )):
328+ if self .N [i ] % 2 == 0 :
329+ k [self .N [i ]// 2 ] = 0
330+ kx = kx [s [0 ]]
331+ kz = kz [s [2 ]]
320332 Ks = np .meshgrid (kx , ky , kz , indexing = 'ij' , sparse = True )
321333 if scaled is True :
322334 Lp = 2 * np .pi / self .L
323335 for i in range (3 ):
324336 Ks [i ] = (Ks [i ]* Lp [i ]).astype (self .float )
325- K = [np .broadcast_to (k , self .complex_shape ()) for k in Ks ]
337+ K = Ks
338+ if broadcast is True :
339+ K = [np .broadcast_to (k , self .complex_shape ()) for k in Ks ]
326340 return K
327341
328342 def get_dealias_filter (self ):
0 commit comments