Run an SQL Query Through SSH

This post shows how to connect to a remote database through an SSH connection. This is especially useful when the database isn’t directly reachable because of firewall rules — for example, a database on a web server that only accepts local connections. As long as you have SSH access, it’s just like logging in and running the query yourself.

This is also possible by setting up SSH port forwarding on your own machine, as shown below (for MySQL you would forward port 3306 rather than 80), but it’s more concise and simpler to do everything in a single program. So instead we will do this in Go (golang).

1
ssh -N -L 8888:127.0.0.1:80 -i mykeyfile.pem user@host

How the Code works

  1. Connect to the local SSH agent, if one is running (`agent.NewClient`)
  2. Read and parse the .pem private key file (`os.ReadFile`)
  3. Create the client config & make the SSH connection (`ssh.Dial`)
  4. Register a dialer that routes database traffic through the SSH connection (`mysql.RegisterDialContext`)
  5. Open the database connection through the tunnel
  6. Pass both back to be used until we close them
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package main

import (
	"database/sql"
	"fmt"
	"log"
	"net"
	"os"
	"context"

	"github.com/go-sql-driver/mysql"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
)

// ViaSSHDialer opens TCP connections through an established SSH client
// connection. The connections are initiated from the SSH server's side, so
// addresses passed to Dial are as seen from that remote host.
type ViaSSHDialer struct {
	client *ssh.Client
}

// Dial opens a TCP connection to addr (host:port) through the SSH connection
// held by d.
func (d *ViaSSHDialer) Dial(addr string) (net.Conn, error) {
	return d.client.Dial("tcp", addr)
}

// DatabaseCreds bundles everything ConnectToDB needs: where and how to log in
// to the SSH server, and which database to open once the tunnel is up.
type DatabaseCreds struct {
	// SSH side of the tunnel.
	SSHHost    string // hostname or IP address of the SSH server
	SSHPort    int    // TCP port the SSH server listens on
	SSHUser    string // user to log in as on the SSH server
	SSHKeyFile string // path to the private key file used for SSH auth

	// Database side, addressed from the SSH server's point of view.
	DBUser string // database user name
	DBPass string // database password
	DBHost string // database host:port, e.g. "localhost:3306" on the SSH server
	DBName string // name of the database to use
}

// main demonstrates ConnectToDB: it opens an SSH tunnel to the example host,
// runs a trivial query through it and prints every returned row.
//
// The credentials below are placeholders; replace them with your own and
// avoid committing real passwords or keys to source control.
func main() {

	db, sshConn, err := ConnectToDB(DatabaseCreds{
		SSHHost:    "123.123.123.123",
		SSHPort:    22,
		SSHUser:    "root",
		SSHKeyFile: "sshkeyfile.pem",
		DBUser:     "root",
		DBPass:     "password",
		DBHost:     "localhost:3306",
		DBName:     "dname",
	})
	if err != nil {
		log.Fatal(err)
	}
	// Deferred calls run in reverse order, so the database handle is closed
	// before the SSH connection it is tunnelled through.
	defer sshConn.Close()
	defer db.Close()

	rows, err := db.Query("SELECT 1=1")
	if err != nil {
		fmt.Printf("Failure: %s\n", err.Error())
		return
	}
	defer rows.Close()

	for rows.Next() {
		var result string
		if err := rows.Scan(&result); err != nil {
			fmt.Printf("Failure: %s\n", err.Error())
			return
		}
		fmt.Printf("Result: %s\n", result)
	}
	// rows.Next returns false both at the end of the result set and on an
	// error during iteration; rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		fmt.Printf("Failure: %s\n", err.Error())
	}
}

// ConnectToDB opens the SSH connection described by dbCreds and returns a
// MySQL handle whose connections are tunnelled through it.
//
// SSH authentication offers the private key in dbCreds.SSHKeyFile first and,
// when a local ssh-agent is reachable through SSH_AUTH_SOCK, the agent's keys
// as well. The database is pinged before returning, so an unreachable or
// misconfigured server is reported here rather than on the first query.
//
// On success the caller owns both returned values and must close them (the
// *sql.DB before the *ssh.Client). On error both are nil and anything opened
// along the way has already been closed.
//
// NOTE(security): the SSH server's host key is NOT verified
// (ssh.InsecureIgnoreHostKey), which leaves the tunnel open to
// man-in-the-middle attacks. Use a real HostKeyCallback (for example one built
// with golang.org/x/crypto/ssh/knownhosts) outside of experiments.
//
// NOTE: the tunnel dialer is registered with the mysql driver under a fixed,
// package-global network name, so a later call re-registers it and new
// database connections from any pool will go through the newest SSH
// connection. Use one tunnel at a time.
func ConnectToDB(dbCreds DatabaseCreds) (*sql.DB, *ssh.Client, error) {
	// Network name under which the SSH dialer is registered with the driver.
	const tunnelNet = "mysql+tcp"

	var agentClient agent.Agent
	// The agent is only consulted while authenticating, which happens inside
	// ssh.Dial below, so its socket can be closed when this function returns.
	if conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
		defer conn.Close()
		agentClient = agent.NewClient(conn)
	}

	pemBytes, err := os.ReadFile(dbCreds.SSHKeyFile)
	if err != nil {
		return nil, nil, fmt.Errorf("reading SSH key %q: %w", dbCreds.SSHKeyFile, err)
	}
	signer, err := ssh.ParsePrivateKey(pemBytes)
	if err != nil {
		return nil, nil, fmt.Errorf("parsing SSH key %q: %w", dbCreds.SSHKeyFile, err)
	}

	authMethods := []ssh.AuthMethod{ssh.PublicKeys(signer)}
	if agentClient != nil {
		authMethods = append(authMethods, ssh.PublicKeysCallback(agentClient.Signers))
	}

	sshConfig := &ssh.ClientConfig{
		User:            dbCreds.SSHUser,
		Auth:            authMethods,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), // see NOTE(security) above
	}

	// JoinHostPort brackets IPv6 literals, which a plain "%s:%d" would not.
	sshAddr := net.JoinHostPort(dbCreds.SSHHost, fmt.Sprint(dbCreds.SSHPort))
	sshConn, err := ssh.Dial("tcp", sshAddr, sshConfig)
	if err != nil {
		return nil, nil, fmt.Errorf("connecting to SSH server %s: %w", sshAddr, err)
	}

	// Route every connection the driver opens on tunnelNet through sshConn.
	dialer := &ViaSSHDialer{client: sshConn}
	mysql.RegisterDialContext(tunnelNet, func(_ context.Context, addr string) (net.Conn, error) {
		return dialer.Dial(addr)
	})

	// Build the driver config directly instead of formatting a DSN string, so
	// credentials are never re-parsed by the driver.
	cfg := mysql.NewConfig()
	cfg.User = dbCreds.DBUser
	cfg.Passwd = dbCreds.DBPass
	cfg.Net = tunnelNet
	cfg.Addr = dbCreds.DBHost
	cfg.DBName = dbCreds.DBName

	connector, err := mysql.NewConnector(cfg)
	if err != nil {
		sshConn.Close()
		return nil, nil, fmt.Errorf("configuring database connection: %w", err)
	}
	db := sql.OpenDB(connector)

	// Opening a *sql.DB does not contact the server; Ping forces a real
	// connection through the tunnel so failures surface here.
	if err := db.Ping(); err != nil {
		db.Close()
		sshConn.Close()
		return nil, nil, fmt.Errorf("connecting to database %s via SSH: %w", dbCreds.DBHost, err)
	}

	fmt.Println("Successfully connected to the db")

	return db, sshConn, nil
}

As you can see, ConnectToDB() does most of the work for us; all we have to do afterwards is use the connections and then close them.

Example In Action

golang connect to db over ssh