Skip to content

Commit

Permalink
add roi selection
Browse files Browse the repository at this point in the history
  • Loading branch information
SGM4 committed Jun 30, 2023
1 parent 62cdbf0 commit a0ce888
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 26 deletions.
50 changes: 37 additions & 13 deletions smartscan/asyncscanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from smartscan.gp import fvGPOptimizer, ndim_aqfunc, compute_costs,plot_acqui_f
from smartscan.sgm4commands import SGM4Commands
from smartscan.utils import closest_point_on_grid
from smartscan.reductions import compose, sharpness, mean_std
from smartscan.reductions import compose, sharpness, mean_std, select_roi
import matplotlib.pyplot as plt

optimizer_pars = {
Expand Down Expand Up @@ -101,6 +101,7 @@ def __init__(
self.use_cost_function = use_cost_function
self.replot = False
self.batch_normalize = batch_normalize
self.last_spectrum = None

if logger is not None:

Expand Down Expand Up @@ -188,13 +189,7 @@ def update_data_and_positions(self):
n_new += 1
except asyncio.QueueEmpty:
break
if self.gp is not None:
pos = np.asarray(self.positions)
vals = np.asarray(self.values)
weights = np.asarray([100,1000])
if self.batch_normalize:
vals = weights * vals / np.mean(vals)
self.gp.tell(pos,vals)
self.tell_gp()
self.logger.info(f'Updated data with {n_new} new points. Total: {len(self.positions)} last Pos {self.positions[-1]} {self.values[-1]}.')
return True
else:
Expand All @@ -221,11 +216,17 @@ async def reduction_loop(self):
if self._raw_data_queue.qsize() > 0:
pos,data = await self._raw_data_queue.get()
self.logger.debug(f'reducing data with shape {data.shape}')
data = select_roi(
data.reshape((640,400)),
x_lim = (330,500),
y_lim = (20,380),
)
self.last_spectrum = data.copy()
reduced = compose(
data.reshape((640,400)),
data,
np.mean,
sharpness,
func_b_kwargs = {'sigma':3, 'r':1, 'reduce':np.mean}
func_b_kwargs = {'sigma':5, 'r':1, 'reduce':np.mean}
)
# reduced = reduced * 1000
# self.logger.info(f'adding {(pos,reduced)} to processed queue')
Expand Down Expand Up @@ -257,7 +258,7 @@ async def gp_loop(self):
output_space_dimension = 1,
output_number = 2,
)
self.gp.tell(np.asarray(self.positions), np.asarray(self.values))
self.tell_gp()
self.logger.info(f'Initialized GP with {len(self.positions)} samples.')
self.gp.init_fvgp(**fvgp_pars)
self.logger.info('Initialized GP. Training...')
Expand All @@ -282,8 +283,15 @@ async def gp_loop(self):
if retrain:
print('############################################################\n\n\n')
self.logger.info(f'Training GP at iteration {iter_counter}, with {len(self.positions)} samples.')
print('\n\n\n############################################################')

old_params = self.gp.hyperparameters.copy()
self.gp.train_gp(**train_pars)
new_params = self.gp.hyperparameters.copy()
s = 'hyperparams: '
for new,old in zip(new_params,old_params):
s += f"{new:,.2f} ({(new-old)/old:.2%}) |"
self.logger.debug(s)
print('\n\n\n############################################################')
answer = self.gp.ask(**ask_pars, acquisition_function=ndim_aqfunc)
next_pos = answer['x']
try:
Expand All @@ -298,24 +306,40 @@ async def gp_loop(self):
await asyncio.sleep(.2)

self.remote.END()

def tell_gp(self):
if self.gp is not None:
pos = np.asarray(self.positions)
vals = np.asarray(self.values)
weights = np.asarray([100,10_000])
if self.batch_normalize:
vals = weights * vals / np.mean(vals)
self.gp.tell(pos,vals)


async def plotting_loop(self):
self.logger.info('starting plotting tool loop')
fig = None
aqf = None

iteration = 0
while not self._should_stop:
iteration += 1
if self.replot:
self.replot = False
fig, aqf = plot_acqui_f(
gp=self.gp,
fig=fig,
pos=np.asarray(self.positions),
val=np.asarray(self.values),
old_aqf = aqf
old_aqf = aqf,
last_spectrum=self.last_spectrum,
)
plt.pause(0.01)
else:
await asyncio.sleep(.2)
# if fig is not None and iteration %100 == 0:
# fig.savefig(f'../results/{self.remote.filename.with_suffix("pdf").name}')

async def all_loops(self):
"""
Expand Down
38 changes: 25 additions & 13 deletions smartscan/gp/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@ def plot_map_with_path_and_scatterplot(positions,values,reduced_maps):
ax[0].scatter(positions[:,1],positions[:,0],s=10,c='b')
ax[1].scatter(positions[:,1],positions[:,0],s=25,c=-values[:,0],cmap='viridis',marker='s')

def plot_acqui_f(gp, fig, pos, val, shape=(50,50), old_aqf = None):
def plot_acqui_f(
gp,
fig,
pos,
val,
shape=(50,50),
old_aqf = None,
last_spectrum = None,
):
""" Plot the acquisition function of a GP
Args:
Expand Down Expand Up @@ -85,18 +93,19 @@ def plot_acqui_f(gp, fig, pos, val, shape=(50,50), old_aqf = None):
fig.clear()

ax = [
fig.add_subplot(241),
fig.add_subplot(242),
fig.add_subplot(243),
fig.add_subplot(244),
fig.add_subplot(245),
fig.add_subplot(246),
fig.add_subplot(247),
fig.add_subplot(248),
fig.add_subplot(331),
fig.add_subplot(332),
fig.add_subplot(333),
fig.add_subplot(334),
fig.add_subplot(335),
fig.add_subplot(336),
fig.add_subplot(337),
fig.add_subplot(338),
fig.add_subplot(339),
]

# fig,ax = plt.subplots(2,2,)
ax = np.asarray(ax).reshape(2,4)
ax = np.asarray(ax).reshape(3,3)
for i, PM, PV in zip(range(2),[PM0,PM1], [sPV0,sPV1]):
PM = np.rot90(PM,k=-1)[:,::-1]
PV = np.rot90(PV,k=-1)[:,::-1]
Expand All @@ -121,21 +130,24 @@ def plot_acqui_f(gp, fig, pos, val, shape=(50,50), old_aqf = None):
ax[0,2].scatter(pos[-1,0],pos[-1,1],s = 25, c='r', marker='o')
ax[1,2].scatter(pos[-1,0],pos[-1,1],s = 25, c='r', marker='o')

ax[0,3].set_title(f'Aq func {aqf.max():.2f}')
ax[0,3].imshow(
ax[2,0].set_title(f'Aq func {aqf.max():.2f}')
ax[2,0].imshow(
aqf,
extent=[*lim_x,*lim_y],
origin='lower',
clim=np.quantile(aqf,(0.01,0.99))
)
if old_aqf is not None:
diff = old_aqf - aqf
ax[1,3].imshow(
ax[2,1].set_title('aqf changes')
ax[2,1].imshow(
diff,
extent=[*lim_x,*lim_y],
origin='lower',
cmap='bwr'
)
if last_spectrum is not None:
ax[2,2].imshow(last_spectrum, clim=np.quantile(last_spectrum,(0.02,0.98)), origin='lower', cmap='terrain')
# ax[i,0].figure.canvas.draw()
# ax[i,1].figure.canvas.draw()

Expand Down

0 comments on commit a0ce888

Please sign in to comment.