Skip to content

Commit c338ded

Browse files
authored
Merge pull request #3859 from unsuman/fix/ext-driver-shutdown-issue
Fix: Server Shutdown Bug
2 parents 0279010 + 62db1d4 commit c338ded

File tree

3 files changed

+50
-14
lines changed

3 files changed

+50
-14
lines changed

pkg/driver/external/client/methods.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,7 @@ func (d *DriverClient) Start(ctx context.Context) (chan error, error) {
9393
func (d *DriverClient) Stop(ctx context.Context) error {
9494
d.logger.Debug("Stopping driver instance")
9595

96-
connCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
97-
defer cancel()
98-
_, err := d.DriverSvc.Stop(connCtx, &emptypb.Empty{})
96+
_, err := d.DriverSvc.Stop(ctx, &emptypb.Empty{})
9997
if err != nil {
10098
d.logger.Errorf("Failed to stop driver instance: %v", err)
10199
return err

pkg/driver/external/server/server.go

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"os/exec"
1515
"os/signal"
1616
"path/filepath"
17+
"sync"
1718
"syscall"
1819
"time"
1920

@@ -35,6 +36,20 @@ type DriverServer struct {
3536
logger *logrus.Logger
3637
}
3738

39+
type listenerTracker struct {
40+
net.Listener
41+
connected chan struct{}
42+
once sync.Once
43+
}
44+
45+
func (t *listenerTracker) Accept() (net.Conn, error) {
46+
c, err := t.Listener.Accept()
47+
if err == nil {
48+
t.once.Do(func() { close(t.connected) })
49+
}
50+
return c, err
51+
}
52+
3853
func Serve(driver driver.Driver) {
3954
logger := logrus.New()
4055
logger.SetLevel(logrus.DebugLevel)
@@ -57,6 +72,11 @@ func Serve(driver driver.Driver) {
5772
}
5873
defer listener.Close()
5974

75+
tListener := &listenerTracker{
76+
Listener: listener,
77+
connected: make(chan struct{}),
78+
}
79+
6080
output := map[string]string{"socketPath": socketPath}
6181
if err := json.NewEncoder(os.Stdout).Encode(output); err != nil {
6282
logger.Fatalf("Failed to encode socket path as JSON: %v", err)
@@ -86,21 +106,26 @@ func Serve(driver driver.Driver) {
86106
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
87107

88108
shutdownCh := make(chan struct{})
109+
var closeOnce sync.Once
110+
closeShutdown := func() { closeOnce.Do(func() { close(shutdownCh) }) }
89111

90112
go func() {
91113
<-sigs
92114
logger.Info("Received shutdown signal, stopping server...")
93-
close(shutdownCh)
115+
closeShutdown()
94116
}()
95117

96118
go func() {
97119
timer := time.NewTimer(60 * time.Second)
98120
defer timer.Stop()
99121

100122
select {
123+
case <-tListener.connected:
124+
logger.Debug("Client connected; disabling 60s startup shutdown")
125+
return
101126
case <-timer.C:
102127
logger.Info("No client connected within 60 seconds, shutting down server...")
103-
close(shutdownCh)
128+
closeShutdown()
104129
case <-shutdownCh:
105130
return
106131
}
@@ -109,7 +134,7 @@ func Serve(driver driver.Driver) {
109134
go func() {
110135
logger.Infof("Starting external driver server for %s", driver.Info().DriverName)
111136
logger.Infof("Server starting on Unix socket: %s", socketPath)
112-
if err := server.Serve(listener); err != nil {
137+
if err := server.Serve(tListener); err != nil {
113138
if errors.Is(err, grpc.ErrServerStopped) {
114139
logger.Errorf("Server stopped: %v", err)
115140
} else {

pkg/driver/qemu/qemu_driver.go

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ func (l *LimaQemuDriver) CreateDisk(ctx context.Context) error {
9696
return EnsureDisk(ctx, qCfg)
9797
}
9898

99-
func (l *LimaQemuDriver) Start(ctx context.Context) (chan error, error) {
100-
ctx, cancel := context.WithCancel(ctx)
99+
func (l *LimaQemuDriver) Start(_ context.Context) (chan error, error) {
100+
ctx, cancel := context.WithCancel(context.Background())
101101
defer func() {
102102
if l.qCmd == nil {
103103
cancel()
@@ -220,8 +220,9 @@ func (l *LimaQemuDriver) Start(ctx context.Context) (chan error, error) {
220220
return nil, err
221221
}
222222
l.qCmd = qCmd
223-
l.qWaitCh = make(chan error)
223+
l.qWaitCh = make(chan error, 1)
224224
go func() {
225+
defer close(l.qWaitCh)
225226
l.qWaitCh <- qCmd.Wait()
226227
}()
227228
l.vhostCmds = vhostCmds
@@ -380,20 +381,32 @@ func (l *LimaQemuDriver) shutdownQEMU(ctx context.Context, timeout time.Duration
380381
logrus.WithError(err).Warnf("failed to send system_powerdown command via the QMP socket %q, forcibly killing QEMU", qmpSockPath)
381382
return l.killQEMU(ctx, timeout, qCmd, qWaitCh)
382383
}
383-
deadline := time.After(timeout)
384+
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
385+
defer timeoutCancel()
386+
384387
select {
385-
case qWaitErr := <-qWaitCh:
388+
case qWaitErr, ok := <-qWaitCh:
389+
if !ok {
390+
logrus.Info("QEMU wait channel was closed")
391+
_ = l.removeVNCFiles()
392+
return l.killVhosts()
393+
}
386394
entry := logrus.NewEntry(logrus.StandardLogger())
387395
if qWaitErr != nil {
388396
entry = entry.WithError(qWaitErr)
389397
}
390398
entry.Info("QEMU has exited")
391399
_ = l.removeVNCFiles()
392400
return errors.Join(qWaitErr, l.killVhosts())
393-
case <-deadline:
401+
case <-timeoutCtx.Done():
402+
if qCmd.ProcessState != nil {
403+
logrus.Info("QEMU has already exited")
404+
_ = l.removeVNCFiles()
405+
return l.killVhosts()
406+
}
407+
logrus.Warnf("QEMU did not exit in %v, forcibly killing QEMU", timeout)
408+
return l.killQEMU(ctx, timeout, qCmd, qWaitCh)
394409
}
395-
logrus.Warnf("QEMU did not exit in %v, forcibly killing QEMU", timeout)
396-
return l.killQEMU(ctx, timeout, qCmd, qWaitCh)
397410
}
398411

399412
func (l *LimaQemuDriver) killQEMU(_ context.Context, _ time.Duration, qCmd *exec.Cmd, qWaitCh <-chan error) error {

0 commit comments

Comments
 (0)