Skip to content
15 changes: 15 additions & 0 deletions internal/validation/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,28 @@ func RunInstanceLifecycleValidation(t *testing.T, config ProviderConfig) {
require.NoError(t, err, "ValidateInstanceImage should pass")
})

t.Run("ValidateInboundPortRestriction", func(t *testing.T) {
err := v1.ValidateInboundPortRestriction(ctx, client, instance, ssh.GetTestPrivateKey())
require.NoError(t, err, "ValidateInboundPortRestriction should pass")
})

if capabilities.IsCapable(v1.CapabilityStopStartInstance) && instance.Stoppable {
t.Run("ValidateStopStartInstance", func(t *testing.T) {
err := v1.ValidateStopStartInstance(ctx, client, instance)
require.NoError(t, err, "ValidateStopStartInstance should pass")
})
}

t.Run("ValidateEastWestConnectivity", func(t *testing.T) {
attrs := v1.CreateInstanceAttrs{
InstanceType: instance.InstanceType,
Location: instance.Location,
PublicKey: ssh.GetTestPublicKey(),
}
err := v1.ValidateEastWestConnectivity(ctx, client, attrs, ssh.GetTestPrivateKey())
require.NoError(t, err, "ValidateEastWestConnectivity should pass")
})

t.Run("ValidateTerminateInstance", func(t *testing.T) {
err := v1.ValidateTerminateInstance(ctx, client, instance)
require.NoError(t, err, "ValidateTerminateInstance should pass")
Expand Down
210 changes: 209 additions & 1 deletion pkg/v1/networking.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package v1

import "context"
import (
"context"
"fmt"

"github.com/brevdev/cloud/pkg/ssh"
"github.com/google/uuid"
)

type CloudModifyFirewall interface {
AddFirewallRulesToInstance(ctx context.Context, args AddFirewallRulesToInstanceArgs) error
Expand Down Expand Up @@ -33,3 +39,205 @@ type PortMapping struct {
FromPort int
ToPort int
}

func ValidateInboundPortRestriction(ctx context.Context, client CloudInstanceReader, instance *Instance, privateKey string) error {
var err error
instance, err = WaitForInstanceLifecycleStatus(ctx, client, instance, LifecycleStatusRunning, PendingToRunningTimeout)
if err != nil {
return err
}

if instance.SSHUser == "" {
return fmt.Errorf("SSH user is not set for instance %s", instance.CloudID)
}
if instance.SSHPort == 0 {
return fmt.Errorf("SSH port is not set for instance %s", instance.CloudID)
}
if instance.PublicIP == "" {
return fmt.Errorf("public IP is not available for instance %s", instance.CloudID)
}

sshClient, err := ssh.ConnectToHost(ctx, ssh.ConnectionConfig{
User: instance.SSHUser,
HostPort: fmt.Sprintf("%s:%d", instance.PublicIP, instance.SSHPort),
PrivKey: privateKey,
})
if err != nil {
return fmt.Errorf("failed to connect to instance via SSH: %w", err)
}
defer func() {
if closeErr := sshClient.Close(); closeErr != nil {
fmt.Printf("warning: failed to close SSH connection: %v\n", closeErr)
}
}()

portsToCheck := []int{21, 23, 25, 53, 80, 443, 993, 995, 3389, 5432, 3306}

for _, port := range portsToCheck {
cmd := fmt.Sprintf("timeout 5 nc -z %s %d", instance.PublicIP, port)
stdout, stderr, err := sshClient.RunCommand(ctx, cmd)

if err == nil {
return fmt.Errorf("security violation: port %d is accessible from external sources, stdout: %s, stderr: %s", port, stdout, stderr)
}

fmt.Printf("Port %d properly blocked (expected): %s\n", port, stderr)
}

cmd := fmt.Sprintf("timeout 5 nc -z %s %d", instance.PublicIP, instance.SSHPort)
stdout, stderr, err := sshClient.RunCommand(ctx, cmd)
if err != nil {
return fmt.Errorf("SSH port %d should be accessible but is not: %w, stdout: %s, stderr: %s", instance.SSHPort, err, stdout, stderr)
}

fmt.Printf("Inbound port validation passed: only SSH port %d is accessible\n", instance.SSHPort)
return nil
}

func ValidateEastWestConnectivity(ctx context.Context, client CloudCreateTerminateInstance, attrs CreateInstanceAttrs, privateKey string) error {
instance1, instance2, err := createTestInstances(ctx, client, attrs)
if err != nil {
return err
}

defer cleanupInstances(ctx, client, instance1, instance2)

instance1, instance2, err = waitForInstancesReady(ctx, client, instance1, instance2, privateKey)
if err != nil {
return err
}

return testConnectivity(ctx, instance1, instance2, privateKey)
}

func createTestInstances(ctx context.Context, client CloudCreateTerminateInstance, attrs CreateInstanceAttrs) (*Instance, *Instance, error) {
attrs1 := attrs
attrs1.RefID = uuid.New().String()
baseName := attrs.Name
if baseName == "" {
baseName = "test-connectivity"
}
name1, err := makeDebuggableName(fmt.Sprintf("%s-east", baseName))
if err != nil {
return nil, nil, fmt.Errorf("failed to generate debuggable name for first instance: %w", err)
}
attrs1.Name = name1

instance1, err := client.CreateInstance(ctx, attrs1)
if err != nil {
return nil, nil, fmt.Errorf("failed to create first instance: %w", err)
}

attrs2 := attrs
attrs2.RefID = uuid.New().String()
baseName2 := attrs.Name
if baseName2 == "" {
baseName2 = "test-connectivity"
}
name2, err := makeDebuggableName(fmt.Sprintf("%s-west", baseName2))
if err != nil {
return instance1, nil, fmt.Errorf("failed to generate debuggable name for second instance: %w", err)
}
attrs2.Name = name2

instance2, err := client.CreateInstance(ctx, attrs2)
if err != nil {
return instance1, nil, fmt.Errorf("failed to create second instance: %w", err)
}

return instance1, instance2, nil
}

func cleanupInstances(ctx context.Context, client CloudCreateTerminateInstance, instance1, instance2 *Instance) {
if instance1 != nil {
if termErr := client.TerminateInstance(ctx, instance1.CloudID); termErr != nil {
fmt.Printf("warning: failed to terminate first instance %s: %v\n", instance1.CloudID, termErr)
}
}
if instance2 != nil {
if termErr := client.TerminateInstance(ctx, instance2.CloudID); termErr != nil {
fmt.Printf("warning: failed to terminate second instance %s: %v\n", instance2.CloudID, termErr)
}
}
}

func waitForInstancesReady(ctx context.Context, client CloudCreateTerminateInstance, instance1, instance2 *Instance, privateKey string) (*Instance, *Instance, error) {
var err error
instance1, err = WaitForInstanceLifecycleStatus(ctx, client, instance1, LifecycleStatusRunning, PendingToRunningTimeout)
if err != nil {
return nil, nil, fmt.Errorf("first instance failed to reach running state: %w", err)
}

instance2, err = WaitForInstanceLifecycleStatus(ctx, client, instance2, LifecycleStatusRunning, PendingToRunningTimeout)
if err != nil {
return nil, nil, fmt.Errorf("second instance failed to reach running state: %w", err)
}

instance1, err = client.GetInstance(ctx, instance1.CloudID)
if err != nil {
return nil, nil, fmt.Errorf("failed to refresh first instance: %w", err)
}

instance2, err = client.GetInstance(ctx, instance2.CloudID)
if err != nil {
return nil, nil, fmt.Errorf("failed to refresh second instance: %w", err)
}

err = ssh.WaitForSSH(ctx, ssh.ConnectionConfig{
User: instance1.SSHUser,
HostPort: fmt.Sprintf("%s:%d", instance1.PublicIP, instance1.SSHPort),
PrivKey: privateKey,
}, ssh.WaitForSSHOptions{
Timeout: RunningSSHTimeout,
})
if err != nil {
return nil, nil, fmt.Errorf("SSH not available on first instance: %w", err)
}

err = ssh.WaitForSSH(ctx, ssh.ConnectionConfig{
User: instance2.SSHUser,
HostPort: fmt.Sprintf("%s:%d", instance2.PublicIP, instance2.SSHPort),
PrivKey: privateKey,
}, ssh.WaitForSSHOptions{
Timeout: RunningSSHTimeout,
})
if err != nil {
return nil, nil, fmt.Errorf("SSH not available on second instance: %w", err)
}

return instance1, instance2, nil
}

func testConnectivity(ctx context.Context, instance1, instance2 *Instance, privateKey string) error {
sshClient1, err := ssh.ConnectToHost(ctx, ssh.ConnectionConfig{
User: instance1.SSHUser,
HostPort: fmt.Sprintf("%s:%d", instance1.PublicIP, instance1.SSHPort),
PrivKey: privateKey,
})
if err != nil {
return fmt.Errorf("failed to connect to first instance via SSH: %w", err)
}
defer func() {
if closeErr := sshClient1.Close(); closeErr != nil {
fmt.Printf("warning: failed to close SSH connection to first instance: %v\n", closeErr)
}
}()

pingCmd := fmt.Sprintf("ping -c 3 -W 5 %s", instance2.PrivateIP)
stdout, stderr, err := sshClient1.RunCommand(ctx, pingCmd)
if err != nil {
return fmt.Errorf("ping from instance1 to instance2 failed: %w, stdout: %s, stderr: %s", err, stdout, stderr)
}

sshTestCmd := fmt.Sprintf("timeout 10 nc -z %s %d", instance2.PrivateIP, instance2.SSHPort)
stdout, stderr, err = sshClient1.RunCommand(ctx, sshTestCmd)
if err != nil {
fmt.Printf("SSH port connectivity test between instances failed (may be expected): %s, stderr: %s\n", stdout, stderr)
} else {
fmt.Printf("SSH port connectivity between instances successful: %s\n", stdout)
}

fmt.Printf("East-west connectivity validation passed: instance1 (%s) can communicate with instance2 (%s)\n",
instance1.CloudID, instance2.CloudID)
return nil
}
Loading