Run a SQL Query Through SSH

This post shows how you can connect to a database remotely via an SSH connection. In our example we’ll be connecting to a MySQL or MariaDB database, but the same method applies to many other SQL databases, such as PostgreSQL. This technique is especially useful when the database isn’t directly accessible because of firewall rules - for example, on a web server. As long as you have SSH access, it’s just like logging in and running the command yourself.

Code Example

Below is the example code, followed by a demo gif showing it running. Before that, here are the key parts of the code:

  1. Create an SSH agent (line 71)
  2. Read the .pem file (line 76)
  3. Create the client config & make a connection (line 86)
  4. Register the connection with the database controller (line 104)
  5. Make the database connection (line 110)
  6. Pass both back to be used until we close them (line 117)
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package main

import (
	"database/sql"
	"fmt"
	"log"
	"net"
	"os"
	"context"

	"github.com/go-sql-driver/mysql"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
)

// ViaSSHDialer dials network addresses through an established SSH client
// connection instead of directly from the local machine.
type ViaSSHDialer struct {
	client *ssh.Client
}

// Dial opens a TCP connection to addr through the SSH client and returns it.
// addr is passed unchanged to the SSH client's Dial.
func (d *ViaSSHDialer) Dial(addr string) (net.Conn, error) {
	return d.client.Dial("tcp", addr)
}

// DatabaseCreds bundles the SSH and database settings that ConnectToDB needs
// to reach a database through an SSH server.
type DatabaseCreds struct {
	SSHHost    string // SSH server hostname/IP
	SSHPort    int    // SSH port
	SSHUser    string // SSH username
	SSHKeyFile string // Path to the SSH private key file
	DBUser     string // DB username
	DBPass     string // DB password
	DBHost     string // DB address as host:port (e.g. "localhost:3306"), dialed through the SSH connection
	DBName     string // Database name
}

// main connects to the database through SSH, runs a trivial query and prints
// every row it returns.
func main() {
	db, sshConn, err := ConnectToDB(DatabaseCreds{
		SSHHost:    "123.123.123.123",
		SSHPort:    22,
		SSHUser:    "root",
		SSHKeyFile: "sshkeyfile.pem",
		DBUser:     "root",
		DBPass:     "password",
		DBHost:     "localhost:3306",
		DBName:     "dname",
	})
	if err != nil {
		// Safe to exit here: no deferred cleanup has been registered yet.
		log.Fatal(err)
	}
	// Deferred calls run in reverse order, so the database handle is closed
	// before the SSH connection it is tunneled through.
	defer sshConn.Close()
	defer db.Close()

	rows, err := db.Query("SELECT 1=1")
	if err != nil {
		fmt.Printf("Failure: %s", err.Error())
		return
	}
	defer rows.Close()

	for rows.Next() {
		var result string
		if err := rows.Scan(&result); err != nil {
			// Return instead of log.Fatal so the deferred Close calls run.
			log.Printf("scanning row: %v", err)
			return
		}
		fmt.Printf("Result: %s\n", result)
	}
	// Next returns false both at the end of the result set and on failure;
	// Err tells the two apart.
	if err := rows.Err(); err != nil {
		log.Printf("iterating rows: %v", err)
	}
}

// ConnectToDB accepts the db and ssh credentials (DatabaseCreds), opens an SSH
// connection and returns a *sql.DB whose connections are tunneled through it.
//
// SSH authentication uses the private key in dbCreds.SSHKeyFile and, when a
// local ssh-agent is reachable through SSH_AUTH_SOCK, the keys held by that
// agent as well.
//
// On success the caller owns both returned handles and should close the
// *sql.DB before the *ssh.Client. On failure both handles are nil and anything
// opened along the way has already been closed.
//
// The tunnel dialer is registered with the mysql driver under the network name
// "mysql+tcp"; a later call registers a new dialer under the same name.
func ConnectToDB(dbCreds DatabaseCreds) (*sql.DB, *ssh.Client, error) {
	// The ssh-agent is optional, so a failed dial is deliberately not an error.
	var agentClient agent.Agent
	if conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
		// The agent is only needed while ssh.Dial authenticates below.
		defer conn.Close()
		agentClient = agent.NewClient(conn)
	}

	pemBytes, err := os.ReadFile(dbCreds.SSHKeyFile)
	if err != nil {
		return nil, nil, fmt.Errorf("reading ssh key file %q: %w", dbCreds.SSHKeyFile, err)
	}
	signer, err := ssh.ParsePrivateKey(pemBytes)
	if err != nil {
		return nil, nil, fmt.Errorf("parsing ssh key file %q: %w", dbCreds.SSHKeyFile, err)
	}

	// NOTE(review): InsecureIgnoreHostKey accepts any server host key, which
	// leaves the connection open to man-in-the-middle attacks. Replace it with
	// a verifying callback before using this outside of a demo.
	sshConfig := &ssh.ClientConfig{
		User:            dbCreds.SSHUser,
		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	// When the agent connection succeeded, offer its keys as a second method.
	if agentClient != nil {
		sshConfig.Auth = append(sshConfig.Auth, ssh.PublicKeysCallback(agentClient.Signers))
	}

	// Connect to the SSH server.
	sshAddr := fmt.Sprintf("%s:%d", dbCreds.SSHHost, dbCreds.SSHPort)
	sshConn, err := ssh.Dial("tcp", sshAddr, sshConfig)
	if err != nil {
		return nil, nil, fmt.Errorf("dialing ssh server %s: %w", sshAddr, err)
	}

	// Route every "mysql+tcp" connection of the mysql driver through the SSH
	// connection. The context is not forwarded to the SSH dial.
	dialer := &ViaSSHDialer{client: sshConn}
	mysql.RegisterDialContext("mysql+tcp", func(_ context.Context, addr string) (net.Conn, error) {
		return dialer.Dial(addr)
	})

	// A regular mysql connection string, using the network name registered above.
	dsn := fmt.Sprintf("%s:%s@mysql+tcp(%s)/%s", dbCreds.DBUser, dbCreds.DBPass, dbCreds.DBHost, dbCreds.DBName)
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		// Best-effort cleanup; the Open error is the one worth reporting.
		sshConn.Close()
		return nil, nil, fmt.Errorf("opening database %q: %w", dbCreds.DBName, err)
	}

	// sql.Open may only validate its arguments; Ping makes sure the database
	// is actually reachable before success is reported.
	if err := db.Ping(); err != nil {
		// Best-effort cleanup; the Ping error is the one worth reporting.
		db.Close()
		sshConn.Close()
		return nil, nil, fmt.Errorf("pinging database %q at %s: %w", dbCreds.DBName, dbCreds.DBHost, err)
	}

	fmt.Println("Successfully connected to the db")

	return db, sshConn, nil
}

As you can see, most of the work is done for us in ConnectToDB(); all we have to do afterwards is use the connection and then close it.

Example In Action

golang connect to db over ssh

An Alternative Route (without Code)

This is also possible by port forwarding through an SSH connection on our machine, as shown below, but it’s often simpler and more concise to keep everything in a single script.

1
# Forward local port 8888 to the MySQL port (3306) on the remote host; -N runs no remote command.
ssh -N -L 8888:127.0.0.1:3306 -i mykeyfile.pem user@host

Using Username/Password SSH (No KeyFile)

If your SSH connection uses only a username and password rather than a keyfile, there is a different ssh AuthMethod for this (shown below). To use it, you’ll also have to remove the keyfile-specific code from the example.

1
2
3
4
5
// Client configuration that authenticates with a password instead of a key file.
// NOTE(review): InsecureIgnoreHostKey accepts any server host key; fine for a
// demo, but host keys should be verified in production.
sshConfig := &ssh.ClientConfig{
	User:            dbCreds.SSHUser,
	Auth:            []ssh.AuthMethod{ssh.Password("mysecretpassword")},
	HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}