remove storage/mysql
This commit is contained in:
parent
3fcad1ec13
commit
bea345a84c
|
@ -43,7 +43,6 @@ import (
|
|||
|
||||
physConsul "github.com/hashicorp/vault/physical/consul"
|
||||
physFoundationDB "github.com/hashicorp/vault/physical/foundationdb"
|
||||
physMySQL "github.com/hashicorp/vault/physical/mysql"
|
||||
physOCI "github.com/hashicorp/vault/physical/oci"
|
||||
physRaft "github.com/hashicorp/vault/physical/raft"
|
||||
physFile "github.com/hashicorp/vault/sdk/physical/file"
|
||||
|
@ -174,7 +173,6 @@ var (
|
|||
"inmem_transactional_ha": physInmem.NewTransactionalInmemHA,
|
||||
"inmem_transactional": physInmem.NewTransactionalInmem,
|
||||
"inmem": physInmem.NewInmem,
|
||||
"mysql": physMySQL.NewMySQLBackend,
|
||||
"oci": physOCI.NewBackend,
|
||||
"raft": physRaft.NewRaftBackend,
|
||||
}
|
||||
|
|
|
@ -1,779 +0,0 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
mysql "github.com/go-sql-driver/mysql"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
)
|
||||
|
||||
// Verify MySQLBackend satisfies the correct interfaces
|
||||
var (
|
||||
_ physical.Backend = (*MySQLBackend)(nil)
|
||||
_ physical.HABackend = (*MySQLBackend)(nil)
|
||||
_ physical.Lock = (*MySQLHALock)(nil)
|
||||
)
|
||||
|
||||
// Unreserved tls key
|
||||
// Reserved values are "true", "false", "skip-verify"
|
||||
const mysqlTLSKey = "default"
|
||||
|
||||
// MySQLBackend is a physical backend that stores data
|
||||
// within MySQL database.
|
||||
type MySQLBackend struct {
|
||||
dbTable string
|
||||
dbLockTable string
|
||||
client *sql.DB
|
||||
statements map[string]*sql.Stmt
|
||||
logger log.Logger
|
||||
permitPool *physical.PermitPool
|
||||
conf map[string]string
|
||||
redirectHost string
|
||||
redirectPort int64
|
||||
haEnabled bool
|
||||
}
|
||||
|
||||
// NewMySQLBackend constructs a MySQL backend using the given API client and
|
||||
// server address and credential for accessing mysql database.
|
||||
func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
||||
var err error
|
||||
|
||||
db, err := NewMySQLClient(conf, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
database := conf["database"]
|
||||
if database == "" {
|
||||
database = "vault"
|
||||
}
|
||||
table := conf["table"]
|
||||
if table == "" {
|
||||
table = "vault"
|
||||
}
|
||||
|
||||
err = validateDBTable(database, table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dbTable := fmt.Sprintf("`%s`.`%s`", database, table)
|
||||
|
||||
maxParStr, ok := conf["max_parallel"]
|
||||
var maxParInt int
|
||||
if ok {
|
||||
maxParInt, err = strconv.Atoi(maxParStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
|
||||
}
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
||||
}
|
||||
} else {
|
||||
maxParInt = physical.DefaultParallelOperations
|
||||
}
|
||||
|
||||
// Check schema exists
|
||||
var schemaExist bool
|
||||
schemaRows, err := db.Query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check mysql schema exist: %w", err)
|
||||
}
|
||||
defer schemaRows.Close()
|
||||
schemaExist = schemaRows.Next()
|
||||
|
||||
// Check table exists
|
||||
var tableExist bool
|
||||
tableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", table, database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check mysql table exist: %w", err)
|
||||
}
|
||||
defer tableRows.Close()
|
||||
tableExist = tableRows.Next()
|
||||
|
||||
// Create the required database if it doesn't exists.
|
||||
if !schemaExist {
|
||||
if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `" + database + "`"); err != nil {
|
||||
return nil, fmt.Errorf("failed to create mysql database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the required table if it doesn't exists.
|
||||
if !tableExist {
|
||||
create_query := "CREATE TABLE IF NOT EXISTS " + dbTable +
|
||||
" (vault_key varbinary(3072), vault_value mediumblob, PRIMARY KEY (vault_key))"
|
||||
if _, err := db.Exec(create_query); err != nil {
|
||||
return nil, fmt.Errorf("failed to create mysql table: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Default value for ha_enabled
|
||||
haEnabledStr, ok := conf["ha_enabled"]
|
||||
if !ok {
|
||||
haEnabledStr = "false"
|
||||
}
|
||||
haEnabled, err := strconv.ParseBool(haEnabledStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("value [%v] of 'ha_enabled' could not be understood", haEnabledStr)
|
||||
}
|
||||
|
||||
locktable, ok := conf["lock_table"]
|
||||
if !ok {
|
||||
locktable = table + "_lock"
|
||||
}
|
||||
|
||||
dbLockTable := "`" + database + "`.`" + locktable + "`"
|
||||
|
||||
// Only create lock table if ha_enabled is true
|
||||
if haEnabled {
|
||||
// Check table exists
|
||||
var lockTableExist bool
|
||||
lockTableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", locktable, database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check mysql table exist: %w", err)
|
||||
}
|
||||
defer lockTableRows.Close()
|
||||
lockTableExist = lockTableRows.Next()
|
||||
|
||||
// Create the required table if it doesn't exists.
|
||||
if !lockTableExist {
|
||||
create_query := "CREATE TABLE IF NOT EXISTS " + dbLockTable +
|
||||
" (node_job varbinary(512), current_leader varbinary(512), PRIMARY KEY (node_job))"
|
||||
if _, err := db.Exec(create_query); err != nil {
|
||||
return nil, fmt.Errorf("failed to create mysql table: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setup the backend.
|
||||
m := &MySQLBackend{
|
||||
dbTable: dbTable,
|
||||
dbLockTable: dbLockTable,
|
||||
client: db,
|
||||
statements: make(map[string]*sql.Stmt),
|
||||
logger: logger,
|
||||
permitPool: physical.NewPermitPool(maxParInt),
|
||||
conf: conf,
|
||||
haEnabled: haEnabled,
|
||||
}
|
||||
|
||||
// Prepare all the statements required
|
||||
statements := map[string]string{
|
||||
"put": "INSERT INTO " + dbTable +
|
||||
" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)",
|
||||
"get": "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?",
|
||||
"delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?",
|
||||
"list": "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?",
|
||||
}
|
||||
|
||||
// Only prepare ha-related statements if we need them
|
||||
if haEnabled {
|
||||
statements["get_lock"] = "SELECT current_leader FROM " + dbLockTable + " WHERE node_job = ?"
|
||||
statements["used_lock"] = "SELECT IS_USED_LOCK(?)"
|
||||
}
|
||||
|
||||
for name, query := range statements {
|
||||
if err := m.prepare(name, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// validateDBTable to prevent SQL injection attacks. This ensures that the database and table names only have valid
|
||||
// characters in them. MySQL allows for more characters that this will allow, but there isn't an easy way of
|
||||
// representing the full Unicode Basic Multilingual Plane to check against.
|
||||
// https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
|
||||
func validateDBTable(db, table string) (err error) {
|
||||
merr := &multierror.Error{}
|
||||
merr = multierror.Append(merr, wrapErr("invalid database: %w", validate(db)))
|
||||
merr = multierror.Append(merr, wrapErr("invalid table: %w", validate(table)))
|
||||
return merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func validate(name string) (err error) {
|
||||
if name == "" {
|
||||
return fmt.Errorf("missing name")
|
||||
}
|
||||
// From: https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
|
||||
// - Permitted characters in quoted identifiers include the full Unicode Basic Multilingual Plane (BMP), except U+0000:
|
||||
// ASCII: U+0001 .. U+007F
|
||||
// Extended: U+0080 .. U+FFFF
|
||||
// - ASCII NUL (U+0000) and supplementary characters (U+10000 and higher) are not permitted in quoted or unquoted identifiers.
|
||||
// - Identifiers may begin with a digit but unless quoted may not consist solely of digits.
|
||||
// - Database, table, and column names cannot end with space characters.
|
||||
//
|
||||
// We are explicitly excluding all space characters (it's easier to deal with)
|
||||
// The name will be quoted, so the all-digit requirement doesn't apply
|
||||
runes := []rune(name)
|
||||
validationErr := fmt.Errorf("invalid character found: can only include printable, non-space characters between [0x0001-0xFFFF]")
|
||||
for _, r := range runes {
|
||||
// U+0000 Explicitly disallowed
|
||||
if r == 0x0000 {
|
||||
return fmt.Errorf("invalid character: cannot include 0x0000")
|
||||
}
|
||||
// Cannot be above 0xFFFF
|
||||
if r > 0xFFFF {
|
||||
return fmt.Errorf("invalid character: cannot include any characters above 0xFFFF")
|
||||
}
|
||||
if r == '`' {
|
||||
return fmt.Errorf("invalid character: cannot include '`' character")
|
||||
}
|
||||
if r == '\'' || r == '"' {
|
||||
return fmt.Errorf("invalid character: cannot include quotes")
|
||||
}
|
||||
// We are excluding non-printable characters (not mentioned in the docs)
|
||||
if !unicode.IsPrint(r) {
|
||||
return validationErr
|
||||
}
|
||||
// We are excluding space characters (not mentioned in the docs)
|
||||
if unicode.IsSpace(r) {
|
||||
return validationErr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func wrapErr(message string, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf(message, err)
|
||||
}
|
||||
|
||||
func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) {
|
||||
var err error
|
||||
|
||||
// Get the MySQL credentials to perform read/write operations.
|
||||
username, ok := conf["username"]
|
||||
if !ok || username == "" {
|
||||
return nil, fmt.Errorf("missing username")
|
||||
}
|
||||
password, ok := conf["password"]
|
||||
if !ok || password == "" {
|
||||
return nil, fmt.Errorf("missing password")
|
||||
}
|
||||
|
||||
// Get or set MySQL server address. Defaults to localhost and default port(3306)
|
||||
address, ok := conf["address"]
|
||||
if !ok {
|
||||
address = "127.0.0.1:3306"
|
||||
}
|
||||
|
||||
maxIdleConnStr, ok := conf["max_idle_connections"]
|
||||
var maxIdleConnInt int
|
||||
if ok {
|
||||
maxIdleConnInt, err = strconv.Atoi(maxIdleConnStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err)
|
||||
}
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnInt)
|
||||
}
|
||||
}
|
||||
|
||||
maxConnLifeStr, ok := conf["max_connection_lifetime"]
|
||||
var maxConnLifeInt int
|
||||
if ok {
|
||||
maxConnLifeInt, err = strconv.Atoi(maxConnLifeStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing max_connection_lifetime parameter: %w", err)
|
||||
}
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("max_connection_lifetime set", "max_connection_lifetime", maxConnLifeInt)
|
||||
}
|
||||
}
|
||||
|
||||
maxParStr, ok := conf["max_parallel"]
|
||||
var maxParInt int
|
||||
if ok {
|
||||
maxParInt, err = strconv.Atoi(maxParStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
|
||||
}
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
||||
}
|
||||
} else {
|
||||
maxParInt = physical.DefaultParallelOperations
|
||||
}
|
||||
|
||||
dsnParams := url.Values{}
|
||||
tlsCaFile, tlsOk := conf["tls_ca_file"]
|
||||
if tlsOk {
|
||||
if err := setupMySQLTLSConfig(tlsCaFile); err != nil {
|
||||
return nil, fmt.Errorf("failed register TLS config: %w", err)
|
||||
}
|
||||
|
||||
dsnParams.Add("tls", mysqlTLSKey)
|
||||
}
|
||||
ptAllowed, ptOk := conf["plaintext_connection_allowed"]
|
||||
if !(ptOk && strings.ToLower(ptAllowed) == "true") && !tlsOk {
|
||||
logger.Warn("No TLS specified, credentials will be sent in plaintext. To mute this warning add 'plaintext_connection_allowed' with a true value to your MySQL configuration in your config file.")
|
||||
}
|
||||
|
||||
// Create MySQL handle for the database.
|
||||
dsn := username + ":" + password + "@tcp(" + address + ")/?" + dsnParams.Encode()
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to mysql: %w", err)
|
||||
}
|
||||
db.SetMaxOpenConns(maxParInt)
|
||||
if maxIdleConnInt != 0 {
|
||||
db.SetMaxIdleConns(maxIdleConnInt)
|
||||
}
|
||||
if maxConnLifeInt != 0 {
|
||||
db.SetConnMaxLifetime(time.Duration(maxConnLifeInt) * time.Second)
|
||||
}
|
||||
|
||||
return db, err
|
||||
}
|
||||
|
||||
// prepare is a helper to prepare a query for future execution
|
||||
func (m *MySQLBackend) prepare(name, query string) error {
|
||||
stmt, err := m.client.Prepare(query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare %q: %w", name, err)
|
||||
}
|
||||
m.statements[name] = stmt
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put is used to insert or update an entry.
|
||||
func (m *MySQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
|
||||
defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
_, err := m.statements["put"].Exec(entry.Key, entry.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get is used to fetch an entry.
|
||||
func (m *MySQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
|
||||
defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
var result []byte
|
||||
err := m.statements["get"].QueryRow(key).Scan(&result)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ent := &physical.Entry{
|
||||
Key: key,
|
||||
Value: result,
|
||||
}
|
||||
return ent, nil
|
||||
}
|
||||
|
||||
// Delete is used to permanently delete an entry
|
||||
func (m *MySQLBackend) Delete(ctx context.Context, key string) error {
|
||||
defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
_, err := m.statements["delete"].Exec(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List is used to list all the keys under a given
|
||||
// prefix, up to the next prefix.
|
||||
func (m *MySQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
|
||||
defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
// Add the % wildcard to the prefix to do the prefix search
|
||||
likePrefix := prefix + "%"
|
||||
rows, err := m.statements["list"].Query(likePrefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute statement: %w", err)
|
||||
}
|
||||
|
||||
var keys []string
|
||||
for rows.Next() {
|
||||
var key string
|
||||
err = rows.Scan(&key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan rows: %w", err)
|
||||
}
|
||||
|
||||
key = strings.TrimPrefix(key, prefix)
|
||||
if i := strings.Index(key, "/"); i == -1 {
|
||||
// Add objects only from the current 'folder'
|
||||
keys = append(keys, key)
|
||||
} else if i != -1 {
|
||||
// Add truncated 'folder' paths
|
||||
keys = strutil.AppendIfMissing(keys, string(key[:i+1]))
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(keys)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// LockWith is used for mutual exclusion based on the given key.
|
||||
func (m *MySQLBackend) LockWith(key, value string) (physical.Lock, error) {
|
||||
l := &MySQLHALock{
|
||||
in: m,
|
||||
key: key,
|
||||
value: value,
|
||||
logger: m.logger,
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (m *MySQLBackend) HAEnabled() bool {
|
||||
return m.haEnabled
|
||||
}
|
||||
|
||||
// MySQLHALock is a MySQL Lock implementation for the HABackend
|
||||
type MySQLHALock struct {
|
||||
in *MySQLBackend
|
||||
key string
|
||||
value string
|
||||
logger log.Logger
|
||||
|
||||
held bool
|
||||
localLock sync.Mutex
|
||||
leaderCh chan struct{}
|
||||
stopCh <-chan struct{}
|
||||
lock *MySQLLock
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
|
||||
i.localLock.Lock()
|
||||
defer i.localLock.Unlock()
|
||||
if i.held {
|
||||
return nil, fmt.Errorf("lock already held")
|
||||
}
|
||||
|
||||
// Attempt an async acquisition
|
||||
didLock := make(chan struct{})
|
||||
failLock := make(chan error, 1)
|
||||
releaseCh := make(chan bool, 1)
|
||||
go i.attemptLock(i.key, i.value, didLock, failLock, releaseCh)
|
||||
|
||||
// Wait for lock acquisition, failure, or shutdown
|
||||
select {
|
||||
case <-didLock:
|
||||
releaseCh <- false
|
||||
case err := <-failLock:
|
||||
return nil, err
|
||||
case <-stopCh:
|
||||
releaseCh <- true
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Create the leader channel
|
||||
i.held = true
|
||||
i.leaderCh = make(chan struct{})
|
||||
|
||||
go i.monitorLock(i.leaderCh)
|
||||
|
||||
i.stopCh = stopCh
|
||||
|
||||
return i.leaderCh, nil
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) attemptLock(key, value string, didLock chan struct{}, failLock chan error, releaseCh chan bool) {
|
||||
lock, err := NewMySQLLock(i.in, i.logger, key, value)
|
||||
if err != nil {
|
||||
failLock <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Set node value
|
||||
i.lock = lock
|
||||
|
||||
err = lock.Lock()
|
||||
if err != nil {
|
||||
failLock <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Signal that lock is held
|
||||
close(didLock)
|
||||
|
||||
// Handle an early abort
|
||||
release := <-releaseCh
|
||||
if release {
|
||||
lock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) monitorLock(leaderCh chan struct{}) {
|
||||
for {
|
||||
// The only way to lose this lock is if someone is
|
||||
// logging into the DB and altering system tables or you lose a connection in
|
||||
// which case you will lose the lock anyway.
|
||||
err := i.hasLock(i.key)
|
||||
if err != nil {
|
||||
// Somehow we lost the lock.... likely because the connection holding
|
||||
// the lock was closed or someone was playing around with the locks in the DB.
|
||||
close(leaderCh)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) Unlock() error {
|
||||
i.localLock.Lock()
|
||||
defer i.localLock.Unlock()
|
||||
if !i.held {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := i.lock.Unlock()
|
||||
|
||||
if err == nil {
|
||||
i.held = false
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// hasLock will check if a lock is held by checking the current lock id against our known ID.
|
||||
func (i *MySQLHALock) hasLock(key string) error {
|
||||
var result sql.NullInt64
|
||||
err := i.in.statements["used_lock"].QueryRow(key).Scan(&result)
|
||||
if err == sql.ErrNoRows || !result.Valid {
|
||||
// This is not an error to us since it just means the lock isn't held
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// IS_USED_LOCK will return the ID of the connection that created the lock.
|
||||
if result.Int64 != GlobalLockID {
|
||||
return ErrLockHeld
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) GetLeader() (string, error) {
|
||||
defer metrics.MeasureSince([]string{"mysql", "lock_get"}, time.Now())
|
||||
var result string
|
||||
err := i.in.statements["get_lock"].QueryRow("leader").Scan(&result)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) Value() (bool, string, error) {
|
||||
leaderkey, err := i.GetLeader()
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
return true, leaderkey, err
|
||||
}
|
||||
|
||||
// MySQLLock provides an easy way to grab and release mysql
|
||||
// locks using the built in GET_LOCK function. Note that these
|
||||
// locks are released when you lose connection to the server.
|
||||
type MySQLLock struct {
|
||||
parentConn *MySQLBackend
|
||||
in *sql.DB
|
||||
logger log.Logger
|
||||
statements map[string]*sql.Stmt
|
||||
key string
|
||||
value string
|
||||
}
|
||||
|
||||
// Errors specific to trying to grab a lock in MySQL
|
||||
var (
|
||||
// This is the GlobalLockID for checking if the lock we got is still the current lock
|
||||
GlobalLockID int64
|
||||
// ErrLockHeld is returned when another vault instance already has a lock held for the given key.
|
||||
ErrLockHeld = errors.New("mysql: lock already held")
|
||||
// ErrUnlockFailed
|
||||
ErrUnlockFailed = errors.New("mysql: unable to release lock, already released or not held by this session")
|
||||
// You were unable to update that you are the new leader in the DB
|
||||
ErrClaimFailed = errors.New("mysql: unable to update DB with new leader information")
|
||||
// Error to throw if between getting the lock and checking the ID of it we lost it.
|
||||
ErrSettingGlobalID = errors.New("mysql: getting global lock id failed")
|
||||
)
|
||||
|
||||
// NewMySQLLock helper function
|
||||
func NewMySQLLock(in *MySQLBackend, l log.Logger, key, value string) (*MySQLLock, error) {
|
||||
// Create a new MySQL connection so we can close this and have no effect on
|
||||
// the rest of the MySQL backend and any cleanup that might need to be done.
|
||||
conn, _ := NewMySQLClient(in.conf, in.logger)
|
||||
|
||||
m := &MySQLLock{
|
||||
parentConn: in,
|
||||
in: conn,
|
||||
logger: l,
|
||||
statements: make(map[string]*sql.Stmt),
|
||||
key: key,
|
||||
value: value,
|
||||
}
|
||||
|
||||
statements := map[string]string{
|
||||
"put": "INSERT INTO " + in.dbLockTable +
|
||||
" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE current_leader=VALUES(current_leader)",
|
||||
}
|
||||
|
||||
for name, query := range statements {
|
||||
if err := m.prepare(name, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// prepare is a helper to prepare a query for future execution
|
||||
func (m *MySQLLock) prepare(name, query string) error {
|
||||
stmt, err := m.in.Prepare(query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare %q: %w", name, err)
|
||||
}
|
||||
m.statements[name] = stmt
|
||||
return nil
|
||||
}
|
||||
|
||||
// update the current cluster leader in the DB. This is used so
|
||||
// we can tell the servers in standby who the active leader is.
|
||||
func (i *MySQLLock) becomeLeader() error {
|
||||
_, err := i.statements["put"].Exec("leader", i.value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lock will try to get a lock for an indefinite amount of time
|
||||
// based on the given key that has been requested.
|
||||
func (i *MySQLLock) Lock() error {
|
||||
defer metrics.MeasureSince([]string{"mysql", "get_lock"}, time.Now())
|
||||
|
||||
// Lock timeout math.MaxInt32 instead of -1 solves compatibility issues with
|
||||
// different MySQL flavours i.e. MariaDB
|
||||
rows, err := i.in.Query("SELECT GET_LOCK(?, ?), IS_USED_LOCK(?)", i.key, math.MaxInt32, i.key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
rows.Next()
|
||||
var lock sql.NullInt64
|
||||
var connectionID sql.NullInt64
|
||||
rows.Scan(&lock, &connectionID)
|
||||
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// 1 is returned from GET_LOCK if it was able to get the lock
|
||||
// 0 if it failed and NULL if some strange error happened.
|
||||
// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_get-lock
|
||||
if !lock.Valid || lock.Int64 != 1 {
|
||||
return ErrLockHeld
|
||||
}
|
||||
|
||||
// Since we have the lock alert the rest of the cluster
|
||||
// that we are now the active leader.
|
||||
err = i.becomeLeader()
|
||||
if err != nil {
|
||||
return ErrLockHeld
|
||||
}
|
||||
|
||||
// This will return the connection ID of NULL if an error happens
|
||||
// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_is-used-lock
|
||||
if !connectionID.Valid {
|
||||
return ErrSettingGlobalID
|
||||
}
|
||||
|
||||
GlobalLockID = connectionID.Int64
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlock just closes the connection. This is because closing the MySQL connection
|
||||
// is a 100% reliable way to close the lock. If you just release the lock you must
|
||||
// do it from the same mysql connection_id that you originally created it from. This
|
||||
// is a huge hastle and I actually couldn't find a clean way to do this although one
|
||||
// likely does exist. Closing the connection however ensures we don't ever get into a
|
||||
// state where we try to release the lock and it hangs it is also much less code.
|
||||
func (i *MySQLLock) Unlock() error {
|
||||
err := i.in.Close()
|
||||
if err != nil {
|
||||
return ErrUnlockFailed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Establish a TLS connection with a given CA certificate
|
||||
// Register a tsl.Config associated with the same key as the dns param from sql.Open
|
||||
// foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default
|
||||
func setupMySQLTLSConfig(tlsCaFile string) error {
|
||||
rootCertPool := x509.NewCertPool()
|
||||
|
||||
pem, err := ioutil.ReadFile(tlsCaFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mysql.RegisterTLSConfig(mysqlTLSKey, &tls.Config{
|
||||
RootCAs: rootCertPool,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,346 +0,0 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
mysql "github.com/go-sql-driver/mysql"
|
||||
|
||||
mysqlhelper "github.com/hashicorp/vault/helper/testhelpers/mysql"
|
||||
)
|
||||
|
||||
func TestMySQLPlaintextCatch(t *testing.T) {
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
database := os.Getenv("MYSQL_DB")
|
||||
if database == "" {
|
||||
database = "test"
|
||||
}
|
||||
|
||||
table := os.Getenv("MYSQL_TABLE")
|
||||
if table == "" {
|
||||
table = "test"
|
||||
}
|
||||
|
||||
username := os.Getenv("MYSQL_USERNAME")
|
||||
password := os.Getenv("MYSQL_PASSWORD")
|
||||
|
||||
// Run vault tests
|
||||
var buf bytes.Buffer
|
||||
log.DefaultOutput = &buf
|
||||
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
NewMySQLBackend(map[string]string{
|
||||
"address": address,
|
||||
"database": database,
|
||||
"table": table,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"plaintext_connection_allowed": "false",
|
||||
}, logger)
|
||||
|
||||
str := buf.String()
|
||||
dataIdx := strings.IndexByte(str, ' ')
|
||||
rest := str[dataIdx+1:]
|
||||
|
||||
if !strings.Contains(rest, "credentials will be sent in plaintext") {
|
||||
t.Fatalf("No warning of plaintext credentials occurred")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLBackend(t *testing.T) {
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
database := os.Getenv("MYSQL_DB")
|
||||
if database == "" {
|
||||
database = "test"
|
||||
}
|
||||
|
||||
table := os.Getenv("MYSQL_TABLE")
|
||||
if table == "" {
|
||||
table = "test"
|
||||
}
|
||||
|
||||
username := os.Getenv("MYSQL_USERNAME")
|
||||
password := os.Getenv("MYSQL_PASSWORD")
|
||||
|
||||
// Run vault tests
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
b, err := NewMySQLBackend(map[string]string{
|
||||
"address": address,
|
||||
"database": database,
|
||||
"table": table,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"plaintext_connection_allowed": "true",
|
||||
"max_connection_lifetime": "1",
|
||||
}, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
mysql := b.(*MySQLBackend)
|
||||
_, err := mysql.client.Exec("DROP TABLE IF EXISTS " + mysql.dbTable + " ," + mysql.dbLockTable)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to drop table: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
physical.ExerciseBackend(t, b)
|
||||
physical.ExerciseBackend_ListPrefix(t, b)
|
||||
}
|
||||
|
||||
func TestMySQLHABackend(t *testing.T) {
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
database := os.Getenv("MYSQL_DB")
|
||||
if database == "" {
|
||||
database = "test"
|
||||
}
|
||||
|
||||
table := os.Getenv("MYSQL_TABLE")
|
||||
if table == "" {
|
||||
table = "test"
|
||||
}
|
||||
|
||||
username := os.Getenv("MYSQL_USERNAME")
|
||||
password := os.Getenv("MYSQL_PASSWORD")
|
||||
|
||||
// Run vault tests
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
config := map[string]string{
|
||||
"address": address,
|
||||
"database": database,
|
||||
"table": table,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"ha_enabled": "true",
|
||||
"plaintext_connection_allowed": "true",
|
||||
}
|
||||
|
||||
b, err := NewMySQLBackend(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
mysql := b.(*MySQLBackend)
|
||||
_, err := mysql.client.Exec("DROP TABLE IF EXISTS " + mysql.dbTable + " ," + mysql.dbLockTable)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to drop table: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
b2, err := NewMySQLBackend(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
physical.ExerciseHABackend(t, b.(physical.HABackend), b2.(physical.HABackend))
|
||||
}
|
||||
|
||||
// TestMySQLHABackend_LockFailPanic is a regression test for the panic shown in
|
||||
// https://github.com/hashicorp/vault/issues/8203 and patched in
|
||||
// https://github.com/hashicorp/vault/pull/8229
|
||||
func TestMySQLHABackend_LockFailPanic(t *testing.T) {
|
||||
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
|
||||
|
||||
cfg, err := mysql.ParseDSN(connURL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := mysqlhelper.TestCredsExist(t, connURL, cfg.User, cfg.Passwd); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
table := "test"
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
config := map[string]string{
|
||||
"address": cfg.Addr,
|
||||
"database": cfg.DBName,
|
||||
"table": table,
|
||||
"username": cfg.User,
|
||||
"password": cfg.Passwd,
|
||||
"ha_enabled": "true",
|
||||
"plaintext_connection_allowed": "true",
|
||||
}
|
||||
|
||||
b, err := NewMySQLBackend(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
b2, err := NewMySQLBackend(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
b1ha := b.(physical.HABackend)
|
||||
b2ha := b2.(physical.HABackend)
|
||||
|
||||
// Copied from ExerciseHABackend - ensuring things are normal at this point
|
||||
// Get the lock
|
||||
lock, err := b1ha.LockWith("foo", "bar")
|
||||
if err != nil {
|
||||
t.Fatalf("initial lock: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to lock
|
||||
leaderCh, err := lock.Lock(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("lock attempt 1: %v", err)
|
||||
}
|
||||
if leaderCh == nil {
|
||||
t.Fatalf("missing leaderCh")
|
||||
}
|
||||
|
||||
// Check the value
|
||||
held, val, err := lock.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !held {
|
||||
t.Errorf("should be held")
|
||||
}
|
||||
if val != "bar" {
|
||||
t.Errorf("expected value bar: %v", err)
|
||||
}
|
||||
|
||||
// Second acquisition should fail
|
||||
lock2, err := b2ha.LockWith("foo", "baz")
|
||||
if err != nil {
|
||||
t.Fatalf("lock 2: %v", err)
|
||||
}
|
||||
|
||||
stopCh := make(chan struct{})
|
||||
time.AfterFunc(10*time.Second, func() {
|
||||
close(stopCh)
|
||||
})
|
||||
|
||||
// Attempt to lock - can't lock because lock1 is held - this is normal
|
||||
leaderCh2, err := lock2.Lock(stopCh)
|
||||
if err != nil {
|
||||
t.Fatalf("stop lock 2: %v", err)
|
||||
}
|
||||
if leaderCh2 != nil {
|
||||
t.Errorf("should not have gotten leaderCh: %v", leaderCh2)
|
||||
}
|
||||
// end normal
|
||||
|
||||
// Clean up the database. When Lock() is called, a new connection is created
|
||||
// using the configuration. If that connection cannot be created, there was a
|
||||
// panic due to not returning with the connection error. Here we intentionally
|
||||
// break the config for b2, so a new connection can't be made, which would
|
||||
// trigger the panic shown in https://github.com/hashicorp/vault/issues/8203
|
||||
cleanup()
|
||||
|
||||
stopCh2 := make(chan struct{})
|
||||
time.AfterFunc(10*time.Second, func() {
|
||||
close(stopCh2)
|
||||
})
|
||||
leaderCh2, err = lock2.Lock(stopCh2)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got none, leaderCh2=%v", leaderCh2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateDBTable(t *testing.T) {
|
||||
type testCase struct {
|
||||
database string
|
||||
table string
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"empty database & table": {"", "", true},
|
||||
"empty database": {"", "a", true},
|
||||
"empty table": {"a", "", true},
|
||||
"ascii database": {"abcde", "a", false},
|
||||
"ascii table": {"a", "abcde", false},
|
||||
"ascii database & table": {"abcde", "abcde", false},
|
||||
"only whitespace db": {" ", "a", true},
|
||||
"only whitespace table": {"a", " ", true},
|
||||
"whitespace prefix db": {" bcde", "a", true},
|
||||
"whitespace middle db": {"ab de", "a", true},
|
||||
"whitespace suffix db": {"abcd ", "a", true},
|
||||
"whitespace prefix table": {"a", " bcde", true},
|
||||
"whitespace middle table": {"a", "ab de", true},
|
||||
"whitespace suffix table": {"a", "abcd ", true},
|
||||
"backtick prefix db": {"`bcde", "a", true},
|
||||
"backtick middle db": {"ab`de", "a", true},
|
||||
"backtick suffix db": {"abcd`", "a", true},
|
||||
"backtick prefix table": {"a", "`bcde", true},
|
||||
"backtick middle table": {"a", "ab`de", true},
|
||||
"backtick suffix table": {"a", "abcd`", true},
|
||||
"single quote prefix db": {"'bcde", "a", true},
|
||||
"single quote middle db": {"ab'de", "a", true},
|
||||
"single quote suffix db": {"abcd'", "a", true},
|
||||
"single quote prefix table": {"a", "'bcde", true},
|
||||
"single quote middle table": {"a", "ab'de", true},
|
||||
"single quote suffix table": {"a", "abcd'", true},
|
||||
"double quote prefix db": {`"bcde`, "a", true},
|
||||
"double quote middle db": {`ab"de`, "a", true},
|
||||
"double quote suffix db": {`abcd"`, "a", true},
|
||||
"double quote prefix table": {"a", `"bcde`, true},
|
||||
"double quote middle table": {"a", `ab"de`, true},
|
||||
"double quote suffix table": {"a", `abcd"`, true},
|
||||
"0x0000 prefix db": {str(0x0000, 'b', 'c'), "a", true},
|
||||
"0x0000 middle db": {str('a', 0x0000, 'c'), "a", true},
|
||||
"0x0000 suffix db": {str('a', 'b', 0x0000), "a", true},
|
||||
"0x0000 prefix table": {"a", str(0x0000, 'b', 'c'), true},
|
||||
"0x0000 middle table": {"a", str('a', 0x0000, 'c'), true},
|
||||
"0x0000 suffix table": {"a", str('a', 'b', 0x0000), true},
|
||||
"unicode > 0xFFFF prefix db": {str(0x10000, 'b', 'c'), "a", true},
|
||||
"unicode > 0xFFFF middle db": {str('a', 0x10000, 'c'), "a", true},
|
||||
"unicode > 0xFFFF suffix db": {str('a', 'b', 0x10000), "a", true},
|
||||
"unicode > 0xFFFF prefix table": {"a", str(0x10000, 'b', 'c'), true},
|
||||
"unicode > 0xFFFF middle table": {"a", str('a', 0x10000, 'c'), true},
|
||||
"unicode > 0xFFFF suffix table": {"a", str('a', 'b', 0x10000), true},
|
||||
"non-printable prefix db": {str(0x0001, 'b', 'c'), "a", true},
|
||||
"non-printable middle db": {str('a', 0x0001, 'c'), "a", true},
|
||||
"non-printable suffix db": {str('a', 'b', 0x0001), "a", true},
|
||||
"non-printable prefix table": {"a", str(0x0001, 'b', 'c'), true},
|
||||
"non-printable middle table": {"a", str('a', 0x0001, 'c'), true},
|
||||
"non-printable suffix table": {"a", str('a', 'b', 0x0001), true},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := validateDBTable(test.database, test.table)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func str(r ...rune) string {
|
||||
return string(r)
|
||||
}
|
Loading…
Reference in New Issue
Block a user