
Compare commits


3 Commits

SHA1 Message Date
bea345a84c remove storage/mysql 2024-07-01 14:20:25 +03:00
3fcad1ec13 remove cockroachdb 2024-07-01 12:29:08 +03:00
13230a9749 remove storage/postgresql 2024-07-01 12:27:51 +03:00
11 changed files with 0 additions and 3253 deletions


@ -41,12 +41,9 @@ import (
logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
logicalDb "github.com/hashicorp/vault/builtin/logical/database"
physCockroachDB "github.com/hashicorp/vault/physical/cockroachdb"
physConsul "github.com/hashicorp/vault/physical/consul"
physFoundationDB "github.com/hashicorp/vault/physical/foundationdb"
physMySQL "github.com/hashicorp/vault/physical/mysql"
physOCI "github.com/hashicorp/vault/physical/oci"
physPostgreSQL "github.com/hashicorp/vault/physical/postgresql"
physRaft "github.com/hashicorp/vault/physical/raft"
physFile "github.com/hashicorp/vault/sdk/physical/file"
physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
@ -168,7 +165,6 @@ var (
}
physicalBackends = map[string]physical.Factory{
"cockroachdb": physCockroachDB.NewCockroachDBBackend,
"consul": physConsul.NewConsulBackend,
"file_transactional": physFile.NewTransactionalFileBackend,
"file": physFile.NewFileBackend,
@ -177,9 +173,7 @@ var (
"inmem_transactional_ha": physInmem.NewTransactionalInmemHA,
"inmem_transactional": physInmem.NewTransactionalInmem,
"inmem": physInmem.NewInmem,
"mysql": physMySQL.NewMySQLBackend,
"oci": physOCI.NewBackend,
"postgresql": physPostgreSQL.NewPostgreSQLBackend,
"raft": physRaft.NewRaftBackend,
}
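
Each value in the registry above is a physical.Factory: a constructor with the same shape as the NewCockroachDBBackend and NewMySQLBackend functions deleted further down, invoked for whichever storage type the server configuration names. The sketch below is illustrative only and relies on the public sdk packages; the backends map and newBackend helper are hypothetical stand-ins for Vault's actual wiring.

package main

import (
	"fmt"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/sdk/physical"
	physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
)

// backends mirrors the physicalBackends map above, keyed by storage type name.
// Only the in-memory backend is registered here to keep the sketch self-contained.
var backends = map[string]physical.Factory{
	"inmem": physInmem.NewInmem,
}

// newBackend resolves a factory by name and invokes it; the cockroachdb, mysql,
// and postgresql entries removed in this compare were looked up the same way.
func newBackend(kind string, conf map[string]string, logger log.Logger) (physical.Backend, error) {
	factory, ok := backends[kind]
	if !ok {
		return nil, fmt.Errorf("unknown physical backend type %q", kind)
	}
	return factory(conf, logger)
}

func main() {
	b, err := newBackend("inmem", nil, log.Default())
	if err != nil {
		panic(err)
	}
	fmt.Printf("constructed backend: %T\n", b)
}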

go.mod (2 lines changed)

@ -33,7 +33,6 @@ require (
github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a
github.com/cenkalti/backoff/v3 v3.2.2
github.com/chrismalek/oktasdk-go v0.0.0-20181212195951-3430665dfaa0
github.com/cockroachdb/cockroach-go v0.0.0-20181001143604-e0a95dfd547c
github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf
github.com/duosecurity/duo_api_golang v0.0.0-20190308151101-6c680f768e74
github.com/dustin/go-humanize v1.0.1
@ -295,7 +294,6 @@ require (
github.com/jackc/pgproto3/v2 v2.3.3 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgtype v1.14.0 // indirect
github.com/jackc/pgx v3.3.0+incompatible // indirect
github.com/jcmturner/aescts/v2 v2.0.0 // indirect
github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect
github.com/jcmturner/gofork v1.7.6 // indirect

go.sum (5 lines changed)

@ -1095,8 +1095,6 @@ github.com/cncf/xds/go v0.0.0-20230428030218-4003588d1b74/go.mod h1:eXthEFrGJvWH
github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/cockroachdb/cockroach-go v0.0.0-20181001143604-e0a95dfd547c h1:2zRrJWIt/f9c9HhNHAgrRgq0San5gRRUJTBXLkchal0=
github.com/cockroachdb/cockroach-go v0.0.0-20181001143604-e0a95dfd547c/go.mod h1:XGLbWH/ujMcbPbhZq52Nv6UrCghb1yGn//133kEsvDk=
github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8=
github.com/cockroachdb/datadriven v0.0.0-20200714090401-bf6692d28da5/go.mod h1:h6jFvWxBdQXxjopDMZyH2UVceIRfR84bdzbkoKrsWNo=
github.com/cockroachdb/errors v1.2.4/go.mod h1:rQD95gz6FARkaKkQXUksEje/d9a6wBJoCr5oaCLELYA=
@ -2092,7 +2090,6 @@ github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 h1:vr3AYkKovP8uR8AvSGGUK1IDqRa5lAAvEkZG1LKaCRc=
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733/go.mod h1:WrMFNQdiFJ80sQsxDoMokWK1W5TQtxBFNpzWTD84ibQ=
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
@ -2128,7 +2125,6 @@ github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrU
github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM=
github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw=
github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4=
github.com/jackc/pgx v3.3.0+incompatible h1:Wa90/+qsITBAPkAZjiByeIGHFcj3Ztu+VzrrIpHjL90=
github.com/jackc/pgx v3.3.0+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I=
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
@ -2683,7 +2679,6 @@ github.com/safchain/ethtool v0.0.0-20210803160452-9aa261dae9b1/go.mod h1:Z0q5wiB
github.com/safchain/ethtool v0.2.0/go.mod h1:WkKB1DnNtvsMlDmQ50sgwowDJV/hGbJSOvJoEXs1AJQ=
github.com/sasha-s/go-deadlock v0.2.0 h1:lMqc+fUb7RrFS3gQLtoQsJ7/6TV/pAIFvBsqX73DK8Y=
github.com/sasha-s/go-deadlock v0.2.0/go.mod h1:StQn567HiB1fF2yJ44N9au7wOhrPS3iZqiDbRupzT10=
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
github.com/sclevine/spec v1.2.0/go.mod h1:W4J29eT/Kzv7/b9IWLB055Z+qvVC9vt0Arko24q7p+U=


@ -1,356 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cockroachdb
import (
"context"
"database/sql"
"fmt"
"sort"
"strconv"
"strings"
"time"
"unicode"
metrics "github.com/armon/go-metrics"
"github.com/cockroachdb/cockroach-go/crdb"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/physical"
// CockroachDB uses the Postgres SQL driver
_ "github.com/jackc/pgx/v4/stdlib"
)
// Verify CockroachDBBackend satisfies the correct interfaces
var (
_ physical.Backend = (*CockroachDBBackend)(nil)
_ physical.Transactional = (*CockroachDBBackend)(nil)
)
const (
defaultTableName = "vault_kv_store"
defaultHATableName = "vault_ha_locks"
)
// CockroachDBBackend is a physical backend that stores data
// within a CockroachDB database.
type CockroachDBBackend struct {
table string
haTable string
client *sql.DB
rawStatements map[string]string
statements map[string]*sql.Stmt
rawHAStatements map[string]string
haStatements map[string]*sql.Stmt
logger log.Logger
permitPool *physical.PermitPool
haEnabled bool
}
// NewCockroachDBBackend constructs a CockroachDB backend using the given
// API client, server address, credentials, and database.
func NewCockroachDBBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
// Get the CockroachDB credentials to perform read/write operations.
connURL, ok := conf["connection_url"]
if !ok || connURL == "" {
return nil, fmt.Errorf("missing connection_url")
}
haEnabled := conf["ha_enabled"] == "true"
dbTable := conf["table"]
if dbTable == "" {
dbTable = defaultTableName
}
err := validateDBTable(dbTable)
if err != nil {
return nil, fmt.Errorf("invalid table: %w", err)
}
dbHATable, ok := conf["ha_table"]
if !ok {
dbHATable = defaultHATableName
}
err = validateDBTable(dbHATable)
if err != nil {
return nil, fmt.Errorf("invalid HA table: %w", err)
}
maxParStr, ok := conf["max_parallel"]
var maxParInt int
if ok {
maxParInt, err = strconv.Atoi(maxParStr)
if err != nil {
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
}
if logger.IsDebug() {
logger.Debug("max_parallel set", "max_parallel", maxParInt)
}
}
// Create CockroachDB handle for the database.
db, err := sql.Open("pgx", connURL)
if err != nil {
return nil, fmt.Errorf("failed to connect to cockroachdb: %w", err)
}
// Create the required tables if they don't exist.
createQuery := "CREATE TABLE IF NOT EXISTS " + dbTable +
" (path STRING, value BYTES, PRIMARY KEY (path))"
if _, err := db.Exec(createQuery); err != nil {
return nil, fmt.Errorf("failed to create CockroachDB table: %w", err)
}
if haEnabled {
createHATableQuery := "CREATE TABLE IF NOT EXISTS " + dbHATable +
"(ha_key TEXT NOT NULL, " +
" ha_identity TEXT NOT NULL, " +
" ha_value TEXT, " +
" valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " +
" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
");"
if _, err := db.Exec(createHATableQuery); err != nil {
return nil, fmt.Errorf("failed to create CockroachDB HA table: %w", err)
}
}
// Setup the backend
c := &CockroachDBBackend{
table: dbTable,
haTable: dbHATable,
client: db,
rawStatements: map[string]string{
"put": "INSERT INTO " + dbTable + " VALUES($1, $2)" +
" ON CONFLICT (path) DO " +
" UPDATE SET (path, value) = ($1, $2)",
"get": "SELECT value FROM " + dbTable + " WHERE path = $1",
"delete": "DELETE FROM " + dbTable + " WHERE path = $1",
"list": "SELECT path FROM " + dbTable + " WHERE path LIKE $1",
},
statements: make(map[string]*sql.Stmt),
rawHAStatements: map[string]string{
"get": "SELECT ha_value FROM " + dbHATable + " WHERE NOW() <= valid_until AND ha_key = $1",
"upsert": "INSERT INTO " + dbHATable + " as t (ha_identity, ha_key, ha_value, valid_until)" +
" VALUES ($1, $2, $3, NOW() + $4) " +
" ON CONFLICT (ha_key) DO " +
" UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4) " +
" WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " +
" (t.ha_identity = $1 AND t.ha_key = $2) ",
"delete": "DELETE FROM " + dbHATable + " WHERE ha_key = $1",
},
haStatements: make(map[string]*sql.Stmt),
logger: logger,
permitPool: physical.NewPermitPool(maxParInt),
haEnabled: haEnabled,
}
// Prepare all the statements required
for name, query := range c.rawStatements {
if err := c.prepare(c.statements, name, query); err != nil {
return nil, err
}
}
if haEnabled {
for name, query := range c.rawHAStatements {
if err := c.prepare(c.haStatements, name, query); err != nil {
return nil, err
}
}
}
return c, nil
}
// prepare is a helper to prepare a query for future execution.
func (c *CockroachDBBackend) prepare(statementMap map[string]*sql.Stmt, name, query string) error {
stmt, err := c.client.Prepare(query)
if err != nil {
return fmt.Errorf("failed to prepare %q: %w", name, err)
}
statementMap[name] = stmt
return nil
}
// Put is used to insert or update an entry.
func (c *CockroachDBBackend) Put(ctx context.Context, entry *physical.Entry) error {
defer metrics.MeasureSince([]string{"cockroachdb", "put"}, time.Now())
c.permitPool.Acquire()
defer c.permitPool.Release()
_, err := c.statements["put"].Exec(entry.Key, entry.Value)
if err != nil {
return err
}
return nil
}
// Get is used to fetch an entry.
func (c *CockroachDBBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
defer metrics.MeasureSince([]string{"cockroachdb", "get"}, time.Now())
c.permitPool.Acquire()
defer c.permitPool.Release()
var result []byte
err := c.statements["get"].QueryRow(key).Scan(&result)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
ent := &physical.Entry{
Key: key,
Value: result,
}
return ent, nil
}
// Delete is used to permanently delete an entry
func (c *CockroachDBBackend) Delete(ctx context.Context, key string) error {
defer metrics.MeasureSince([]string{"cockroachdb", "delete"}, time.Now())
c.permitPool.Acquire()
defer c.permitPool.Release()
_, err := c.statements["delete"].Exec(key)
if err != nil {
return err
}
return nil
}
// List is used to list all the keys under a given
// prefix, up to the next prefix.
func (c *CockroachDBBackend) List(ctx context.Context, prefix string) ([]string, error) {
defer metrics.MeasureSince([]string{"cockroachdb", "list"}, time.Now())
c.permitPool.Acquire()
defer c.permitPool.Release()
likePrefix := prefix + "%"
rows, err := c.statements["list"].Query(likePrefix)
if err != nil {
return nil, err
}
defer rows.Close()
var keys []string
for rows.Next() {
var key string
err = rows.Scan(&key)
if err != nil {
return nil, fmt.Errorf("failed to scan rows: %w", err)
}
key = strings.TrimPrefix(key, prefix)
if i := strings.Index(key, "/"); i == -1 {
// Add objects only from the current 'folder'
keys = append(keys, key)
} else if i != -1 {
// Add truncated 'folder' paths
keys = strutil.AppendIfMissing(keys, string(key[:i+1]))
}
}
sort.Strings(keys)
return keys, nil
}
// Transaction is used to run multiple entries via a transaction
func (c *CockroachDBBackend) Transaction(ctx context.Context, txns []*physical.TxnEntry) error {
defer metrics.MeasureSince([]string{"cockroachdb", "transaction"}, time.Now())
if len(txns) == 0 {
return nil
}
c.permitPool.Acquire()
defer c.permitPool.Release()
return crdb.ExecuteTx(context.Background(), c.client, nil, func(tx *sql.Tx) error {
return c.transaction(tx, txns)
})
}
func (c *CockroachDBBackend) transaction(tx *sql.Tx, txns []*physical.TxnEntry) error {
deleteStmt, err := tx.Prepare(c.rawStatements["delete"])
if err != nil {
return err
}
putStmt, err := tx.Prepare(c.rawStatements["put"])
if err != nil {
return err
}
for _, op := range txns {
switch op.Operation {
case physical.DeleteOperation:
_, err = deleteStmt.Exec(op.Entry.Key)
case physical.PutOperation:
_, err = putStmt.Exec(op.Entry.Key, op.Entry.Value)
default:
return fmt.Errorf("%q is not a supported transaction operation", op.Operation)
}
if err != nil {
return err
}
}
return nil
}
// validateDBTable against the CockroachDB rules for table names:
// https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#identifiers
//
// - All values that accept an identifier must:
// - Begin with a Unicode letter or an underscore (_). Subsequent characters can be letters,
// - underscores, digits (0-9), or dollar signs ($).
// - Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example,
// name accepts Unreserved or Column Name keywords.
//
// The docs do state that we can bypass these rules with double quotes; however, I think it
// is safer to just require these rules across the board.
func validateDBTable(dbTable string) (err error) {
// Check if this is 'database.table' formatted. If so, split them apart and check the two
// parts from each other
split := strings.SplitN(dbTable, ".", 2)
if len(split) == 2 {
merr := &multierror.Error{}
merr = multierror.Append(merr, wrapErr("invalid database: %w", validateDBTable(split[0])))
merr = multierror.Append(merr, wrapErr("invalid table name: %w", validateDBTable(split[1])))
return merr.ErrorOrNil()
}
// Disallow SQL keywords as the table name
if sqlKeywords[strings.ToUpper(dbTable)] {
return fmt.Errorf("name must not be a SQL keyword")
}
runes := []rune(dbTable)
for i, r := range runes {
if i == 0 && !unicode.IsLetter(r) && r != '_' {
return fmt.Errorf("must use a letter or an underscore as the first character")
}
if !unicode.IsLetter(r) && r != '_' && !unicode.IsDigit(r) && r != '$' {
return fmt.Errorf("must only contain letters, underscores, digits, and dollar signs")
}
if r == '`' || r == '\'' || r == '"' {
return fmt.Errorf("cannot contain backticks, single quotes, or double quotes")
}
}
return nil
}
func wrapErr(message string, err error) error {
if err == nil {
return nil
}
return fmt.Errorf(message, err)
}
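
The file removed above is the entire CockroachDB storage backend. For reference, here is a minimal usage sketch of the constructor and methods it defined; the connection URL is a placeholder for a locally running insecure CockroachDB node, and "table" and "ha_enabled" fall back to their defaults ("vault_kv_store", disabled) as in NewCockroachDBBackend.

package main

import (
	"context"
	"fmt"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/physical/cockroachdb"
	"github.com/hashicorp/vault/sdk/physical"
)

func main() {
	// connection_url is illustrative; any reachable CockroachDB cluster works.
	b, err := cockroachdb.NewCockroachDBBackend(map[string]string{
		"connection_url": "postgresql://root@127.0.0.1:26257/vault?sslmode=disable",
	}, log.Default())
	if err != nil {
		panic(err)
	}

	ctx := context.Background()
	if err := b.Put(ctx, &physical.Entry{Key: "foo/bar", Value: []byte("baz")}); err != nil {
		panic(err)
	}

	entry, err := b.Get(ctx, "foo/bar")
	if err != nil || entry == nil {
		panic(err)
	}
	fmt.Printf("%s => %s\n", entry.Key, entry.Value)

	keys, _ := b.List(ctx, "foo/")
	fmt.Println(keys) // [bar]
}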


@ -1,204 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cockroachdb
import (
"database/sql"
"fmt"
"sync"
"time"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/physical"
)
const (
// The lock TTL matches the default that Consul API uses, 15 seconds.
// Used as part of SQL commands to set/extend lock expiry time relative to
// database clock.
CockroachDBLockTTLSeconds = 15
// The amount of time to wait between the lock renewals
CockroachDBLockRenewInterval = 5 * time.Second
// CockroachDBLockRetryInterval is the amount of time to wait
// if a lock fails before trying again.
CockroachDBLockRetryInterval = time.Second
)
// Verify backend satisfies the correct interfaces.
var (
_ physical.HABackend = (*CockroachDBBackend)(nil)
_ physical.Lock = (*CockroachDBLock)(nil)
)
type CockroachDBLock struct {
backend *CockroachDBBackend
key string
value string
identity string
lock sync.Mutex
renewTicker *time.Ticker
// ttlSeconds is how long a lock is valid for.
ttlSeconds int
// renewInterval is how much time to wait between lock renewals. must be << ttl.
renewInterval time.Duration
// retryInterval is how much time to wait between attempts to grab the lock.
retryInterval time.Duration
}
func (c *CockroachDBBackend) HAEnabled() bool {
return c.haEnabled
}
func (c *CockroachDBBackend) LockWith(key, value string) (physical.Lock, error) {
identity, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
return &CockroachDBLock{
backend: c,
key: key,
value: value,
identity: identity,
ttlSeconds: CockroachDBLockTTLSeconds,
renewInterval: CockroachDBLockRenewInterval,
retryInterval: CockroachDBLockRetryInterval,
}, nil
}
// Lock tries to acquire the lock by repeatedly trying to create a record in the
// CockroachDB table. It will block until either the stop channel is closed or
// the lock could be acquired successfully. The returned channel will be closed
// once the lock in the CockroachDB table cannot be renewed, either due to an
// error speaking to CockroachDB or because someone else has taken it.
func (l *CockroachDBLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
l.lock.Lock()
defer l.lock.Unlock()
var (
success = make(chan struct{})
errors = make(chan error, 1)
leader = make(chan struct{})
)
go l.tryToLock(stopCh, success, errors)
select {
case <-success:
// After acquiring it successfully, we must renew the lock periodically.
l.renewTicker = time.NewTicker(l.renewInterval)
go l.periodicallyRenewLock(leader)
case err := <-errors:
return nil, err
case <-stopCh:
return nil, nil
}
return leader, nil
}
// Unlock releases the lock by deleting the lock record from the
// CockroachDB table.
func (l *CockroachDBLock) Unlock() error {
c := l.backend
c.permitPool.Acquire()
defer c.permitPool.Release()
if l.renewTicker != nil {
l.renewTicker.Stop()
}
_, err := c.haStatements["delete"].Exec(l.key)
return err
}
// Value checks whether or not the lock is held by any instance of CockroachDBLock,
// including this one, and returns the current value.
func (l *CockroachDBLock) Value() (bool, string, error) {
c := l.backend
c.permitPool.Acquire()
defer c.permitPool.Release()
var result string
err := c.haStatements["get"].QueryRow(l.key).Scan(&result)
switch err {
case nil:
return true, result, nil
case sql.ErrNoRows:
return false, "", nil
default:
return false, "", err
}
}
// tryToLock tries to create a new item in CockroachDB every `retryInterval`.
// As long as the item cannot be created (because it already exists), it will
// be retried. If the operation fails due to an error, it is sent to the errors
// channel. When the lock could be acquired successfully, the success channel
// is closed.
func (l *CockroachDBLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) {
ticker := time.NewTicker(l.retryInterval)
defer ticker.Stop()
for {
select {
case <-stop:
return
case <-ticker.C:
gotlock, err := l.writeItem()
switch {
case err != nil:
// Send to the error channel and don't block if full.
select {
case errors <- err:
default:
}
return
case gotlock:
close(success)
return
}
}
}
}
func (l *CockroachDBLock) periodicallyRenewLock(done chan struct{}) {
for range l.renewTicker.C {
gotlock, err := l.writeItem()
if err != nil || !gotlock {
close(done)
l.renewTicker.Stop()
return
}
}
}
// Attempts to put/update the CockroachDB item using condition expressions to
// evaluate the TTL. Returns true if the lock was obtained, false if not.
// If false, the error may be nil or non-nil: nil indicates simply that someone
// else has the lock, whereas non-nil means that something unexpected happened.
func (l *CockroachDBLock) writeItem() (bool, error) {
c := l.backend
c.permitPool.Acquire()
defer c.permitPool.Release()
sqlResult, err := c.haStatements["upsert"].Exec(l.identity, l.key, l.value, fmt.Sprintf("%d seconds", l.ttlSeconds))
if err != nil {
return false, err
}
if sqlResult == nil {
return false, fmt.Errorf("empty SQL response received")
}
ar, err := sqlResult.RowsAffected()
if err != nil {
return false, err
}
return ar == 1, nil
}
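
The file above layers high-availability leader election on top of the same backend, using the vault_ha_locks table. A hedged sketch of driving the physical.HABackend and physical.Lock interfaces it implements follows; the lock key, node identity, and package name are illustrative, and b is assumed to come from NewCockroachDBBackend with "ha_enabled" set to "true".

// Package leaderelect is a hypothetical illustration, not part of Vault.
package leaderelect

import (
	"fmt"

	"github.com/hashicorp/vault/sdk/physical"
)

// AcquireLeadership blocks until the shared lock is held, then waits until it is lost.
func AcquireLeadership(b physical.Backend) error {
	ha, ok := b.(physical.HABackend)
	if !ok || !ha.HAEnabled() {
		return fmt.Errorf("backend is not running in HA mode")
	}

	lock, err := ha.LockWith("core/leader", "node-a") // key and identity are illustrative
	if err != nil {
		return err
	}

	stopCh := make(chan struct{})
	leaderCh, err := lock.Lock(stopCh) // blocks until acquired, an error, or stopCh is closed
	if err != nil {
		return err
	}
	if leaderCh == nil {
		return nil // the attempt was stopped before the lock was acquired
	}
	defer lock.Unlock()

	if held, holder, err := lock.Value(); err == nil {
		fmt.Println(held, holder) // true node-a while renewals keep succeeding
	}

	<-leaderCh // closed once the lock can no longer be renewed
	return nil
}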


@ -1,214 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cockroachdb
import (
"context"
"database/sql"
"fmt"
"net/url"
"os"
"testing"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/docker"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/physical"
)
type Config struct {
docker.ServiceURL
TableName string
HATableName string
}
var _ docker.ServiceConfig = &Config{}
func prepareCockroachDBTestContainer(t *testing.T) (func(), *Config) {
if retURL := os.Getenv("CR_URL"); retURL != "" {
s, err := docker.NewServiceURLParse(retURL)
if err != nil {
t.Fatal(err)
}
return func() {}, &Config{
ServiceURL: *s,
TableName: "vault." + defaultTableName,
HATableName: "vault." + defaultHATableName,
}
}
runner, err := docker.NewServiceRunner(docker.RunOptions{
ImageRepo: "docker.mirror.hashicorp.services/cockroachdb/cockroach",
ImageTag: "release-1.0",
ContainerName: "cockroachdb",
Cmd: []string{"start", "--insecure"},
Ports: []string{"26257/tcp"},
})
if err != nil {
t.Fatalf("Could not start docker CockroachDB: %s", err)
}
svc, err := runner.StartService(context.Background(), connectCockroachDB)
if err != nil {
t.Fatalf("Could not start docker CockroachDB: %s", err)
}
return svc.Cleanup, svc.Config.(*Config)
}
func connectCockroachDB(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
u := url.URL{
Scheme: "postgresql",
User: url.UserPassword("root", ""),
Host: fmt.Sprintf("%s:%d", host, port),
RawQuery: "sslmode=disable",
}
db, err := sql.Open("pgx", u.String())
if err != nil {
return nil, err
}
defer db.Close()
database := "vault"
_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", database))
if err != nil {
return nil, err
}
return &Config{
ServiceURL: *docker.NewServiceURL(u),
TableName: database + "." + defaultTableName,
HATableName: database + "." + defaultHATableName,
}, nil
}
func TestCockroachDBBackend(t *testing.T) {
cleanup, config := prepareCockroachDBTestContainer(t)
defer cleanup()
hae := os.Getenv("CR_HA_ENABLED")
if hae == "" {
hae = "true"
}
// Run vault tests
logger := logging.NewVaultLogger(log.Debug)
b1, err := NewCockroachDBBackend(map[string]string{
"connection_url": config.URL().String(),
"table": config.TableName,
"ha_table": config.HATableName,
"ha_enabled": hae,
}, logger)
if err != nil {
t.Fatalf("Failed to create new backend: %v", err)
}
b2, err := NewCockroachDBBackend(map[string]string{
"connection_url": config.URL().String(),
"table": config.TableName,
"ha_table": config.HATableName,
"ha_enabled": hae,
}, logger)
if err != nil {
t.Fatalf("Failed to create new backend: %v", err)
}
defer func() {
truncate(t, b1)
truncate(t, b2)
}()
physical.ExerciseBackend(t, b1)
truncate(t, b1)
physical.ExerciseBackend_ListPrefix(t, b1)
truncate(t, b1)
physical.ExerciseTransactionalBackend(t, b1)
truncate(t, b1)
ha1, ok1 := b1.(physical.HABackend)
ha2, ok2 := b2.(physical.HABackend)
if !ok1 || !ok2 {
t.Fatalf("CockroachDB does not implement HABackend")
}
if ha1.HAEnabled() && ha2.HAEnabled() {
logger.Info("Running ha backend tests")
physical.ExerciseHABackend(t, ha1, ha2)
}
}
func truncate(t *testing.T, b physical.Backend) {
crdb := b.(*CockroachDBBackend)
_, err := crdb.client.Exec("TRUNCATE TABLE " + crdb.table)
if err != nil {
t.Fatalf("Failed to drop table: %v", err)
}
if crdb.haEnabled {
_, err = crdb.client.Exec("TRUNCATE TABLE " + crdb.haTable)
if err != nil {
t.Fatalf("Failed to drop table: %v", err)
}
}
}
func TestValidateDBTable(t *testing.T) {
type testCase struct {
table string
expectErr bool
}
tests := map[string]testCase{
"first character is letter": {"abcdef", false},
"first character is underscore": {"_bcdef", false},
"exclamation point": {"ab!def", true},
"at symbol": {"ab@def", true},
"hash": {"ab#def", true},
"percent": {"ab%def", true},
"carrot": {"ab^def", true},
"ampersand": {"ab&def", true},
"star": {"ab*def", true},
"left paren": {"ab(def", true},
"right paren": {"ab)def", true},
"dash": {"ab-def", true},
"digit": {"a123ef", false},
"dollar end": {"abcde$", false},
"dollar middle": {"ab$def", false},
"dollar start": {"$bcdef", true},
"backtick prefix": {"`bcdef", true},
"backtick middle": {"ab`def", true},
"backtick suffix": {"abcde`", true},
"single quote prefix": {"'bcdef", true},
"single quote middle": {"ab'def", true},
"single quote suffix": {"abcde'", true},
"double quote prefix": {`"bcdef`, true},
"double quote middle": {`ab"def`, true},
"double quote suffix": {`abcde"`, true},
"underscore with all runes": {"_bcd123__a__$", false},
"all runes": {"abcd123__a__$", false},
"default table name": {defaultTableName, false},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
err := validateDBTable(test.table)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
t.Run(fmt.Sprintf("database: %s", name), func(t *testing.T) {
dbTable := fmt.Sprintf("%s.%s", test.table, test.table)
err := validateDBTable(dbTable)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}


@ -1,441 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cockroachdb
// sqlKeywords is a reference of all of the keywords that we do not allow for use as the table name
// Referenced from:
// https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#identifiers
// -> https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords
// -> https://www.cockroachlabs.com/docs/stable/sql-grammar.html
var sqlKeywords = map[string]bool{
// reserved_keyword
// https://www.cockroachlabs.com/docs/stable/sql-grammar.html#reserved_keyword
"ALL": true,
"ANALYSE": true,
"ANALYZE": true,
"AND": true,
"ANY": true,
"ARRAY": true,
"AS": true,
"ASC": true,
"ASYMMETRIC": true,
"BOTH": true,
"CASE": true,
"CAST": true,
"CHECK": true,
"COLLATE": true,
"COLUMN": true,
"CONCURRENTLY": true,
"CONSTRAINT": true,
"CREATE": true,
"CURRENT_CATALOG": true,
"CURRENT_DATE": true,
"CURRENT_ROLE": true,
"CURRENT_SCHEMA": true,
"CURRENT_TIME": true,
"CURRENT_TIMESTAMP": true,
"CURRENT_USER": true,
"DEFAULT": true,
"DEFERRABLE": true,
"DESC": true,
"DISTINCT": true,
"DO": true,
"ELSE": true,
"END": true,
"EXCEPT": true,
"FALSE": true,
"FETCH": true,
"FOR": true,
"FOREIGN": true,
"FROM": true,
"GRANT": true,
"GROUP": true,
"HAVING": true,
"IN": true,
"INITIALLY": true,
"INTERSECT": true,
"INTO": true,
"LATERAL": true,
"LEADING": true,
"LIMIT": true,
"LOCALTIME": true,
"LOCALTIMESTAMP": true,
"NOT": true,
"NULL": true,
"OFFSET": true,
"ON": true,
"ONLY": true,
"OR": true,
"ORDER": true,
"PLACING": true,
"PRIMARY": true,
"REFERENCES": true,
"RETURNING": true,
"SELECT": true,
"SESSION_USER": true,
"SOME": true,
"SYMMETRIC": true,
"TABLE": true,
"THEN": true,
"TO": true,
"TRAILING": true,
"TRUE": true,
"UNION": true,
"UNIQUE": true,
"USER": true,
"USING": true,
"VARIADIC": true,
"WHEN": true,
"WHERE": true,
"WINDOW": true,
"WITH": true,
// cockroachdb_extra_reserved_keyword
// https://www.cockroachlabs.com/docs/stable/sql-grammar.html#cockroachdb_extra_reserved_keyword
"INDEX": true,
"NOTHING": true,
// type_func_name_keyword
// https://www.cockroachlabs.com/docs/stable/sql-grammar.html#type_func_name_keyword
"COLLATION": true,
"CROSS": true,
"FULL": true,
"INNER": true,
"ILIKE": true,
"IS": true,
"ISNULL": true,
"JOIN": true,
"LEFT": true,
"LIKE": true,
"NATURAL": true,
"NONE": true,
"NOTNULL": true,
"OUTER": true,
"OVERLAPS": true,
"RIGHT": true,
"SIMILAR": true,
"FAMILY": true,
// col_name_keyword
// https://www.cockroachlabs.com/docs/stable/sql-grammar.html#col_name_keyword
"ANNOTATE_TYPE": true,
"BETWEEN": true,
"BIGINT": true,
"BIT": true,
"BOOLEAN": true,
"CHAR": true,
"CHARACTER": true,
"CHARACTERISTICS": true,
"COALESCE": true,
"DEC": true,
"DECIMAL": true,
"EXISTS": true,
"EXTRACT": true,
"EXTRACT_DURATION": true,
"FLOAT": true,
"GREATEST": true,
"GROUPING": true,
"IF": true,
"IFERROR": true,
"IFNULL": true,
"INT": true,
"INTEGER": true,
"INTERVAL": true,
"ISERROR": true,
"LEAST": true,
"NULLIF": true,
"NUMERIC": true,
"OUT": true,
"OVERLAY": true,
"POSITION": true,
"PRECISION": true,
"REAL": true,
"ROW": true,
"SMALLINT": true,
"SUBSTRING": true,
"TIME": true,
"TIMETZ": true,
"TIMESTAMP": true,
"TIMESTAMPTZ": true,
"TREAT": true,
"TRIM": true,
"VALUES": true,
"VARBIT": true,
"VARCHAR": true,
"VIRTUAL": true,
"WORK": true,
// unreserved_keyword
// https://www.cockroachlabs.com/docs/stable/sql-grammar.html#unreserved_keyword
"ABORT": true,
"ACTION": true,
"ADD": true,
"ADMIN": true,
"AGGREGATE": true,
"ALTER": true,
"AT": true,
"AUTOMATIC": true,
"AUTHORIZATION": true,
"BACKUP": true,
"BEGIN": true,
"BIGSERIAL": true,
"BLOB": true,
"BOOL": true,
"BUCKET_COUNT": true,
"BUNDLE": true,
"BY": true,
"BYTEA": true,
"BYTES": true,
"CACHE": true,
"CANCEL": true,
"CASCADE": true,
"CHANGEFEED": true,
"CLUSTER": true,
"COLUMNS": true,
"COMMENT": true,
"COMMIT": true,
"COMMITTED": true,
"COMPACT": true,
"COMPLETE": true,
"CONFLICT": true,
"CONFIGURATION": true,
"CONFIGURATIONS": true,
"CONFIGURE": true,
"CONSTRAINTS": true,
"CONVERSION": true,
"COPY": true,
"COVERING": true,
"CREATEROLE": true,
"CUBE": true,
"CURRENT": true,
"CYCLE": true,
"DATA": true,
"DATABASE": true,
"DATABASES": true,
"DATE": true,
"DAY": true,
"DEALLOCATE": true,
"DELETE": true,
"DEFERRED": true,
"DISCARD": true,
"DOMAIN": true,
"DOUBLE": true,
"DROP": true,
"ENCODING": true,
"ENUM": true,
"ESCAPE": true,
"EXCLUDE": true,
"EXECUTE": true,
"EXPERIMENTAL": true,
"EXPERIMENTAL_AUDIT": true,
"EXPERIMENTAL_FINGERPRINTS": true,
"EXPERIMENTAL_RELOCATE": true,
"EXPERIMENTAL_REPLICA": true,
"EXPIRATION": true,
"EXPLAIN": true,
"EXPORT": true,
"EXTENSION": true,
"FILES": true,
"FILTER": true,
"FIRST": true,
"FLOAT4": true,
"FLOAT8": true,
"FOLLOWING": true,
"FORCE_INDEX": true,
"FUNCTION": true,
"GLOBAL": true,
"GRANTS": true,
"GROUPS": true,
"HASH": true,
"HIGH": true,
"HISTOGRAM": true,
"HOUR": true,
"IMMEDIATE": true,
"IMPORT": true,
"INCLUDE": true,
"INCREMENT": true,
"INCREMENTAL": true,
"INDEXES": true,
"INET": true,
"INJECT": true,
"INSERT": true,
"INT2": true,
"INT2VECTOR": true,
"INT4": true,
"INT8": true,
"INT64": true,
"INTERLEAVE": true,
"INVERTED": true,
"ISOLATION": true,
"JOB": true,
"JOBS": true,
"JSON": true,
"JSONB": true,
"KEY": true,
"KEYS": true,
"KV": true,
"LANGUAGE": true,
"LAST": true,
"LC_COLLATE": true,
"LC_CTYPE": true,
"LEASE": true,
"LESS": true,
"LEVEL": true,
"LIST": true,
"LOCAL": true,
"LOCKED": true,
"LOGIN": true,
"LOOKUP": true,
"LOW": true,
"MATCH": true,
"MATERIALIZED": true,
"MAXVALUE": true,
"MERGE": true,
"MINUTE": true,
"MINVALUE": true,
"MONTH": true,
"NAMES": true,
"NAN": true,
"NAME": true,
"NEXT": true,
"NO": true,
"NORMAL": true,
"NO_INDEX_JOIN": true,
"NOCREATEROLE": true,
"NOLOGIN": true,
"NOWAIT": true,
"NULLS": true,
"IGNORE_FOREIGN_KEYS": true,
"OF": true,
"OFF": true,
"OID": true,
"OIDS": true,
"OIDVECTOR": true,
"OPERATOR": true,
"OPT": true,
"OPTION": true,
"OPTIONS": true,
"ORDINALITY": true,
"OTHERS": true,
"OVER": true,
"OWNED": true,
"PARENT": true,
"PARTIAL": true,
"PARTITION": true,
"PARTITIONS": true,
"PASSWORD": true,
"PAUSE": true,
"PHYSICAL": true,
"PLAN": true,
"PLANS": true,
"PRECEDING": true,
"PREPARE": true,
"PRESERVE": true,
"PRIORITY": true,
"PUBLIC": true,
"PUBLICATION": true,
"QUERIES": true,
"QUERY": true,
"RANGE": true,
"RANGES": true,
"READ": true,
"RECURSIVE": true,
"REF": true,
"REGCLASS": true,
"REGPROC": true,
"REGPROCEDURE": true,
"REGNAMESPACE": true,
"REGTYPE": true,
"REINDEX": true,
"RELEASE": true,
"RENAME": true,
"REPEATABLE": true,
"REPLACE": true,
"RESET": true,
"RESTORE": true,
"RESTRICT": true,
"RESUME": true,
"REVOKE": true,
"ROLE": true,
"ROLES": true,
"ROLLBACK": true,
"ROLLUP": true,
"ROWS": true,
"RULE": true,
"SETTING": true,
"SETTINGS": true,
"STATUS": true,
"SAVEPOINT": true,
"SCATTER": true,
"SCHEMA": true,
"SCHEMAS": true,
"SCRUB": true,
"SEARCH": true,
"SECOND": true,
"SERIAL": true,
"SERIALIZABLE": true,
"SERIAL2": true,
"SERIAL4": true,
"SERIAL8": true,
"SEQUENCE": true,
"SEQUENCES": true,
"SERVER": true,
"SESSION": true,
"SESSIONS": true,
"SET": true,
"SHARE": true,
"SHOW": true,
"SIMPLE": true,
"SKIP": true,
"SMALLSERIAL": true,
"SNAPSHOT": true,
"SPLIT": true,
"SQL": true,
"START": true,
"STATISTICS": true,
"STDIN": true,
"STORE": true,
"STORED": true,
"STORING": true,
"STRICT": true,
"STRING": true,
"SUBSCRIPTION": true,
"SYNTAX": true,
"SYSTEM": true,
"TABLES": true,
"TEMP": true,
"TEMPLATE": true,
"TEMPORARY": true,
"TESTING_RELOCATE": true,
"TEXT": true,
"TIES": true,
"TRACE": true,
"TRANSACTION": true,
"TRIGGER": true,
"TRUNCATE": true,
"TRUSTED": true,
"TYPE": true,
"THROTTLING": true,
"UNBOUNDED": true,
"UNCOMMITTED": true,
"UNKNOWN": true,
"UNLOGGED": true,
"UNSPLIT": true,
"UNTIL": true,
"UPDATE": true,
"UPSERT": true,
"UUID": true,
"USE": true,
"USERS": true,
"VALID": true,
"VALIDATE": true,
"VALUE": true,
"VARYING": true,
"VIEW": true,
"WITHIN": true,
"WITHOUT": true,
"WRITE": true,
"YEAR": true,
"ZONE": true,
}


@ -1,779 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package mysql
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"io/ioutil"
"math"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"time"
"unicode"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
metrics "github.com/armon/go-metrics"
mysql "github.com/go-sql-driver/mysql"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/physical"
)
// Verify MySQLBackend satisfies the correct interfaces
var (
_ physical.Backend = (*MySQLBackend)(nil)
_ physical.HABackend = (*MySQLBackend)(nil)
_ physical.Lock = (*MySQLHALock)(nil)
)
// Unreserved tls key
// Reserved values are "true", "false", "skip-verify"
const mysqlTLSKey = "default"
// MySQLBackend is a physical backend that stores data
// within MySQL database.
type MySQLBackend struct {
dbTable string
dbLockTable string
client *sql.DB
statements map[string]*sql.Stmt
logger log.Logger
permitPool *physical.PermitPool
conf map[string]string
redirectHost string
redirectPort int64
haEnabled bool
}
// NewMySQLBackend constructs a MySQL backend using the given API client and
// server address and credentials for accessing the MySQL database.
func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
var err error
db, err := NewMySQLClient(conf, logger)
if err != nil {
return nil, err
}
database := conf["database"]
if database == "" {
database = "vault"
}
table := conf["table"]
if table == "" {
table = "vault"
}
err = validateDBTable(database, table)
if err != nil {
return nil, err
}
dbTable := fmt.Sprintf("`%s`.`%s`", database, table)
maxParStr, ok := conf["max_parallel"]
var maxParInt int
if ok {
maxParInt, err = strconv.Atoi(maxParStr)
if err != nil {
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
}
if logger.IsDebug() {
logger.Debug("max_parallel set", "max_parallel", maxParInt)
}
} else {
maxParInt = physical.DefaultParallelOperations
}
// Check schema exists
var schemaExist bool
schemaRows, err := db.Query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", database)
if err != nil {
return nil, fmt.Errorf("failed to check mysql schema exist: %w", err)
}
defer schemaRows.Close()
schemaExist = schemaRows.Next()
// Check table exists
var tableExist bool
tableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", table, database)
if err != nil {
return nil, fmt.Errorf("failed to check mysql table exist: %w", err)
}
defer tableRows.Close()
tableExist = tableRows.Next()
// Create the required database if it doesn't exist.
if !schemaExist {
if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `" + database + "`"); err != nil {
return nil, fmt.Errorf("failed to create mysql database: %w", err)
}
}
// Create the required table if it doesn't exist.
if !tableExist {
create_query := "CREATE TABLE IF NOT EXISTS " + dbTable +
" (vault_key varbinary(3072), vault_value mediumblob, PRIMARY KEY (vault_key))"
if _, err := db.Exec(create_query); err != nil {
return nil, fmt.Errorf("failed to create mysql table: %w", err)
}
}
// Default value for ha_enabled
haEnabledStr, ok := conf["ha_enabled"]
if !ok {
haEnabledStr = "false"
}
haEnabled, err := strconv.ParseBool(haEnabledStr)
if err != nil {
return nil, fmt.Errorf("value [%v] of 'ha_enabled' could not be understood", haEnabledStr)
}
locktable, ok := conf["lock_table"]
if !ok {
locktable = table + "_lock"
}
dbLockTable := "`" + database + "`.`" + locktable + "`"
// Only create lock table if ha_enabled is true
if haEnabled {
// Check table exists
var lockTableExist bool
lockTableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", locktable, database)
if err != nil {
return nil, fmt.Errorf("failed to check mysql table exist: %w", err)
}
defer lockTableRows.Close()
lockTableExist = lockTableRows.Next()
// Create the required table if it doesn't exist.
if !lockTableExist {
create_query := "CREATE TABLE IF NOT EXISTS " + dbLockTable +
" (node_job varbinary(512), current_leader varbinary(512), PRIMARY KEY (node_job))"
if _, err := db.Exec(create_query); err != nil {
return nil, fmt.Errorf("failed to create mysql table: %w", err)
}
}
}
// Setup the backend.
m := &MySQLBackend{
dbTable: dbTable,
dbLockTable: dbLockTable,
client: db,
statements: make(map[string]*sql.Stmt),
logger: logger,
permitPool: physical.NewPermitPool(maxParInt),
conf: conf,
haEnabled: haEnabled,
}
// Prepare all the statements required
statements := map[string]string{
"put": "INSERT INTO " + dbTable +
" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)",
"get": "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?",
"delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?",
"list": "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?",
}
// Only prepare ha-related statements if we need them
if haEnabled {
statements["get_lock"] = "SELECT current_leader FROM " + dbLockTable + " WHERE node_job = ?"
statements["used_lock"] = "SELECT IS_USED_LOCK(?)"
}
for name, query := range statements {
if err := m.prepare(name, query); err != nil {
return nil, err
}
}
return m, nil
}
// validateDBTable guards against SQL injection attacks by ensuring that the database and table names only have valid
// characters in them. MySQL allows more characters than this will allow, but there isn't an easy way of
// representing the full Unicode Basic Multilingual Plane to check against.
// https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
func validateDBTable(db, table string) (err error) {
merr := &multierror.Error{}
merr = multierror.Append(merr, wrapErr("invalid database: %w", validate(db)))
merr = multierror.Append(merr, wrapErr("invalid table: %w", validate(table)))
return merr.ErrorOrNil()
}
func validate(name string) (err error) {
if name == "" {
return fmt.Errorf("missing name")
}
// From: https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
// - Permitted characters in quoted identifiers include the full Unicode Basic Multilingual Plane (BMP), except U+0000:
// ASCII: U+0001 .. U+007F
// Extended: U+0080 .. U+FFFF
// - ASCII NUL (U+0000) and supplementary characters (U+10000 and higher) are not permitted in quoted or unquoted identifiers.
// - Identifiers may begin with a digit but unless quoted may not consist solely of digits.
// - Database, table, and column names cannot end with space characters.
//
// We are explicitly excluding all space characters (it's easier to deal with)
// The name will be quoted, so the all-digit requirement doesn't apply
runes := []rune(name)
validationErr := fmt.Errorf("invalid character found: can only include printable, non-space characters between [0x0001-0xFFFF]")
for _, r := range runes {
// U+0000 Explicitly disallowed
if r == 0x0000 {
return fmt.Errorf("invalid character: cannot include 0x0000")
}
// Cannot be above 0xFFFF
if r > 0xFFFF {
return fmt.Errorf("invalid character: cannot include any characters above 0xFFFF")
}
if r == '`' {
return fmt.Errorf("invalid character: cannot include '`' character")
}
if r == '\'' || r == '"' {
return fmt.Errorf("invalid character: cannot include quotes")
}
// We are excluding non-printable characters (not mentioned in the docs)
if !unicode.IsPrint(r) {
return validationErr
}
// We are excluding space characters (not mentioned in the docs)
if unicode.IsSpace(r) {
return validationErr
}
}
return nil
}
func wrapErr(message string, err error) error {
if err == nil {
return nil
}
return fmt.Errorf(message, err)
}
func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) {
var err error
// Get the MySQL credentials to perform read/write operations.
username, ok := conf["username"]
if !ok || username == "" {
return nil, fmt.Errorf("missing username")
}
password, ok := conf["password"]
if !ok || password == "" {
return nil, fmt.Errorf("missing password")
}
// Get or set MySQL server address. Defaults to localhost and default port(3306)
address, ok := conf["address"]
if !ok {
address = "127.0.0.1:3306"
}
maxIdleConnStr, ok := conf["max_idle_connections"]
var maxIdleConnInt int
if ok {
maxIdleConnInt, err = strconv.Atoi(maxIdleConnStr)
if err != nil {
return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err)
}
if logger.IsDebug() {
logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnInt)
}
}
maxConnLifeStr, ok := conf["max_connection_lifetime"]
var maxConnLifeInt int
if ok {
maxConnLifeInt, err = strconv.Atoi(maxConnLifeStr)
if err != nil {
return nil, fmt.Errorf("failed parsing max_connection_lifetime parameter: %w", err)
}
if logger.IsDebug() {
logger.Debug("max_connection_lifetime set", "max_connection_lifetime", maxConnLifeInt)
}
}
maxParStr, ok := conf["max_parallel"]
var maxParInt int
if ok {
maxParInt, err = strconv.Atoi(maxParStr)
if err != nil {
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
}
if logger.IsDebug() {
logger.Debug("max_parallel set", "max_parallel", maxParInt)
}
} else {
maxParInt = physical.DefaultParallelOperations
}
dsnParams := url.Values{}
tlsCaFile, tlsOk := conf["tls_ca_file"]
if tlsOk {
if err := setupMySQLTLSConfig(tlsCaFile); err != nil {
return nil, fmt.Errorf("failed register TLS config: %w", err)
}
dsnParams.Add("tls", mysqlTLSKey)
}
ptAllowed, ptOk := conf["plaintext_connection_allowed"]
if !(ptOk && strings.ToLower(ptAllowed) == "true") && !tlsOk {
logger.Warn("No TLS specified, credentials will be sent in plaintext. To mute this warning add 'plaintext_connection_allowed' with a true value to your MySQL configuration in your config file.")
}
// Create MySQL handle for the database.
dsn := username + ":" + password + "@tcp(" + address + ")/?" + dsnParams.Encode()
db, err := sql.Open("mysql", dsn)
if err != nil {
return nil, fmt.Errorf("failed to connect to mysql: %w", err)
}
db.SetMaxOpenConns(maxParInt)
if maxIdleConnInt != 0 {
db.SetMaxIdleConns(maxIdleConnInt)
}
if maxConnLifeInt != 0 {
db.SetConnMaxLifetime(time.Duration(maxConnLifeInt) * time.Second)
}
return db, err
}
// prepare is a helper to prepare a query for future execution
func (m *MySQLBackend) prepare(name, query string) error {
stmt, err := m.client.Prepare(query)
if err != nil {
return fmt.Errorf("failed to prepare %q: %w", name, err)
}
m.statements[name] = stmt
return nil
}
// Put is used to insert or update an entry.
func (m *MySQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())
m.permitPool.Acquire()
defer m.permitPool.Release()
_, err := m.statements["put"].Exec(entry.Key, entry.Value)
if err != nil {
return err
}
return nil
}
// Get is used to fetch an entry.
func (m *MySQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())
m.permitPool.Acquire()
defer m.permitPool.Release()
var result []byte
err := m.statements["get"].QueryRow(key).Scan(&result)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
ent := &physical.Entry{
Key: key,
Value: result,
}
return ent, nil
}
// Delete is used to permanently delete an entry
func (m *MySQLBackend) Delete(ctx context.Context, key string) error {
defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())
m.permitPool.Acquire()
defer m.permitPool.Release()
_, err := m.statements["delete"].Exec(key)
if err != nil {
return err
}
return nil
}
// List is used to list all the keys under a given
// prefix, up to the next prefix.
func (m *MySQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())
m.permitPool.Acquire()
defer m.permitPool.Release()
// Add the % wildcard to the prefix to do the prefix search
likePrefix := prefix + "%"
rows, err := m.statements["list"].Query(likePrefix)
if err != nil {
return nil, fmt.Errorf("failed to execute statement: %w", err)
}
var keys []string
for rows.Next() {
var key string
err = rows.Scan(&key)
if err != nil {
return nil, fmt.Errorf("failed to scan rows: %w", err)
}
key = strings.TrimPrefix(key, prefix)
if i := strings.Index(key, "/"); i == -1 {
// Add objects only from the current 'folder'
keys = append(keys, key)
} else if i != -1 {
// Add truncated 'folder' paths
keys = strutil.AppendIfMissing(keys, string(key[:i+1]))
}
}
sort.Strings(keys)
return keys, nil
}
// LockWith is used for mutual exclusion based on the given key.
func (m *MySQLBackend) LockWith(key, value string) (physical.Lock, error) {
l := &MySQLHALock{
in: m,
key: key,
value: value,
logger: m.logger,
}
return l, nil
}
func (m *MySQLBackend) HAEnabled() bool {
return m.haEnabled
}
// MySQLHALock is a MySQL Lock implementation for the HABackend
type MySQLHALock struct {
in *MySQLBackend
key string
value string
logger log.Logger
held bool
localLock sync.Mutex
leaderCh chan struct{}
stopCh <-chan struct{}
lock *MySQLLock
}
func (i *MySQLHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
i.localLock.Lock()
defer i.localLock.Unlock()
if i.held {
return nil, fmt.Errorf("lock already held")
}
// Attempt an async acquisition
didLock := make(chan struct{})
failLock := make(chan error, 1)
releaseCh := make(chan bool, 1)
go i.attemptLock(i.key, i.value, didLock, failLock, releaseCh)
// Wait for lock acquisition, failure, or shutdown
select {
case <-didLock:
releaseCh <- false
case err := <-failLock:
return nil, err
case <-stopCh:
releaseCh <- true
return nil, nil
}
// Create the leader channel
i.held = true
i.leaderCh = make(chan struct{})
go i.monitorLock(i.leaderCh)
i.stopCh = stopCh
return i.leaderCh, nil
}
func (i *MySQLHALock) attemptLock(key, value string, didLock chan struct{}, failLock chan error, releaseCh chan bool) {
lock, err := NewMySQLLock(i.in, i.logger, key, value)
if err != nil {
failLock <- err
return
}
// Set node value
i.lock = lock
err = lock.Lock()
if err != nil {
failLock <- err
return
}
// Signal that lock is held
close(didLock)
// Handle an early abort
release := <-releaseCh
if release {
lock.Unlock()
}
}
func (i *MySQLHALock) monitorLock(leaderCh chan struct{}) {
for {
// The only way to lose this lock is if someone is
// logging into the DB and altering system tables, or you lose a connection, in
// which case you will lose the lock anyway.
err := i.hasLock(i.key)
if err != nil {
// Somehow we lost the lock.... likely because the connection holding
// the lock was closed or someone was playing around with the locks in the DB.
close(leaderCh)
return
}
time.Sleep(5 * time.Second)
}
}
func (i *MySQLHALock) Unlock() error {
i.localLock.Lock()
defer i.localLock.Unlock()
if !i.held {
return nil
}
err := i.lock.Unlock()
if err == nil {
i.held = false
return nil
}
return err
}
// hasLock will check if a lock is held by checking the current lock id against our known ID.
func (i *MySQLHALock) hasLock(key string) error {
var result sql.NullInt64
err := i.in.statements["used_lock"].QueryRow(key).Scan(&result)
if err == sql.ErrNoRows || !result.Valid {
// This is not an error to us since it just means the lock isn't held
return nil
}
if err != nil {
return err
}
// IS_USED_LOCK will return the ID of the connection that created the lock.
if result.Int64 != GlobalLockID {
return ErrLockHeld
}
return nil
}
func (i *MySQLHALock) GetLeader() (string, error) {
defer metrics.MeasureSince([]string{"mysql", "lock_get"}, time.Now())
var result string
err := i.in.statements["get_lock"].QueryRow("leader").Scan(&result)
if err == sql.ErrNoRows {
return "", err
}
return result, nil
}
func (i *MySQLHALock) Value() (bool, string, error) {
leaderkey, err := i.GetLeader()
if err != nil {
return false, "", err
}
return true, leaderkey, err
}
// MySQLLock provides an easy way to grab and release mysql
// locks using the built-in GET_LOCK function. Note that these
// locks are released when you lose connection to the server.
type MySQLLock struct {
parentConn *MySQLBackend
in *sql.DB
logger log.Logger
statements map[string]*sql.Stmt
key string
value string
}
// Errors specific to trying to grab a lock in MySQL
var (
// GlobalLockID is used to check whether the lock we got is still the current lock.
GlobalLockID int64
// ErrLockHeld is returned when another vault instance already has a lock held for the given key.
ErrLockHeld = errors.New("mysql: lock already held")
// ErrUnlockFailed is returned when the lock cannot be released.
ErrUnlockFailed = errors.New("mysql: unable to release lock, already released or not held by this session")
// ErrClaimFailed is returned when we are unable to record in the DB that we are the new leader.
ErrClaimFailed = errors.New("mysql: unable to update DB with new leader information")
// ErrSettingGlobalID is returned if, between getting the lock and checking its ID, we lost it.
ErrSettingGlobalID = errors.New("mysql: getting global lock id failed")
)
// NewMySQLLock helper function
func NewMySQLLock(in *MySQLBackend, l log.Logger, key, value string) (*MySQLLock, error) {
// Create a new MySQL connection so we can close this and have no effect on
// the rest of the MySQL backend and any cleanup that might need to be done.
conn, _ := NewMySQLClient(in.conf, in.logger)
m := &MySQLLock{
parentConn: in,
in: conn,
logger: l,
statements: make(map[string]*sql.Stmt),
key: key,
value: value,
}
statements := map[string]string{
"put": "INSERT INTO " + in.dbLockTable +
" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE current_leader=VALUES(current_leader)",
}
for name, query := range statements {
if err := m.prepare(name, query); err != nil {
return nil, err
}
}
return m, nil
}
// prepare is a helper to prepare a query for future execution
func (m *MySQLLock) prepare(name, query string) error {
stmt, err := m.in.Prepare(query)
if err != nil {
return fmt.Errorf("failed to prepare %q: %w", name, err)
}
m.statements[name] = stmt
return nil
}
// becomeLeader updates the current cluster leader in the DB. This is used so
// we can tell the servers in standby who the active leader is.
func (i *MySQLLock) becomeLeader() error {
_, err := i.statements["put"].Exec("leader", i.value)
if err != nil {
return err
}
return nil
}
// Lock will try to get a lock for an indefinite amount of time
// based on the given key that has been requested.
func (i *MySQLLock) Lock() error {
defer metrics.MeasureSince([]string{"mysql", "get_lock"}, time.Now())
// Lock timeout math.MaxInt32 instead of -1 solves compatibility issues with
// different MySQL flavours i.e. MariaDB
rows, err := i.in.Query("SELECT GET_LOCK(?, ?), IS_USED_LOCK(?)", i.key, math.MaxInt32, i.key)
if err != nil {
return err
}
defer rows.Close()
rows.Next()
var lock sql.NullInt64
var connectionID sql.NullInt64
rows.Scan(&lock, &connectionID)
if rows.Err() != nil {
return rows.Err()
}
// 1 is returned from GET_LOCK if it was able to get the lock
// 0 if it failed and NULL if some strange error happened.
// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_get-lock
if !lock.Valid || lock.Int64 != 1 {
return ErrLockHeld
}
// Since we have the lock alert the rest of the cluster
// that we are now the active leader.
err = i.becomeLeader()
if err != nil {
return ErrLockHeld
}
// This will return the connection ID of NULL if an error happens
// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_is-used-lock
if !connectionID.Valid {
return ErrSettingGlobalID
}
GlobalLockID = connectionID.Int64
return nil
}
// Unlock just closes the connection, because closing the MySQL connection
// is a 100% reliable way to release the lock. If you release the lock explicitly, you must
// do it from the same mysql connection_id that you originally created it from. That
// is a huge hassle, and I couldn't find a clean way to do it, although one
// likely does exist. Closing the connection, however, ensures we never get into a
// state where we try to release the lock and it hangs; it is also much less code.
func (i *MySQLLock) Unlock() error {
err := i.in.Close()
if err != nil {
return ErrUnlockFailed
}
return nil
}
// setupMySQLTLSConfig establishes a TLS connection with a given CA certificate by
// registering a tls.Config associated with the same key as the dsn param passed to sql.Open:
// foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default
func setupMySQLTLSConfig(tlsCaFile string) error {
rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(tlsCaFile)
if err != nil {
return err
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return err
}
err = mysql.RegisterTLSConfig(mysqlTLSKey, &tls.Config{
RootCAs: rootCertPool,
})
if err != nil {
return err
}
return nil
}
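
As with CockroachDB, the MySQL backend removed above implements the physical.Backend contract, with optional HA locking built on GET_LOCK. A minimal usage sketch follows; the address and credentials are placeholders, "database" and "table" default to "vault" as in NewMySQLBackend, and "plaintext_connection_allowed" only silences the warning NewMySQLClient logs when TLS is not configured.

package main

import (
	"context"
	"fmt"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/physical/mysql"
	"github.com/hashicorp/vault/sdk/physical"
)

func main() {
	b, err := mysql.NewMySQLBackend(map[string]string{
		"address":                      "127.0.0.1:3306", // placeholder MySQL server
		"username":                     "vault",
		"password":                     "secret",
		"plaintext_connection_allowed": "true",
	}, log.Default())
	if err != nil {
		panic(err)
	}

	ctx := context.Background()
	if err := b.Put(ctx, &physical.Entry{Key: "foo", Value: []byte("bar")}); err != nil {
		panic(err)
	}

	entry, err := b.Get(ctx, "foo")
	if err != nil || entry == nil {
		panic(err)
	}
	fmt.Printf("%s => %s\n", entry.Key, entry.Value)
}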


@ -1,346 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package mysql
import (
"bytes"
"os"
"strings"
"testing"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/physical"
_ "github.com/go-sql-driver/mysql"
mysql "github.com/go-sql-driver/mysql"
mysqlhelper "github.com/hashicorp/vault/helper/testhelpers/mysql"
)
func TestMySQLPlaintextCatch(t *testing.T) {
address := os.Getenv("MYSQL_ADDR")
if address == "" {
t.SkipNow()
}
database := os.Getenv("MYSQL_DB")
if database == "" {
database = "test"
}
table := os.Getenv("MYSQL_TABLE")
if table == "" {
table = "test"
}
username := os.Getenv("MYSQL_USERNAME")
password := os.Getenv("MYSQL_PASSWORD")
// Run vault tests
var buf bytes.Buffer
log.DefaultOutput = &buf
logger := logging.NewVaultLogger(log.Debug)
NewMySQLBackend(map[string]string{
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"plaintext_connection_allowed": "false",
}, logger)
str := buf.String()
dataIdx := strings.IndexByte(str, ' ')
rest := str[dataIdx+1:]
if !strings.Contains(rest, "credentials will be sent in plaintext") {
t.Fatalf("No warning of plaintext credentials occurred")
}
}
func TestMySQLBackend(t *testing.T) {
address := os.Getenv("MYSQL_ADDR")
if address == "" {
t.SkipNow()
}
database := os.Getenv("MYSQL_DB")
if database == "" {
database = "test"
}
table := os.Getenv("MYSQL_TABLE")
if table == "" {
table = "test"
}
username := os.Getenv("MYSQL_USERNAME")
password := os.Getenv("MYSQL_PASSWORD")
// Run vault tests
logger := logging.NewVaultLogger(log.Debug)
b, err := NewMySQLBackend(map[string]string{
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"plaintext_connection_allowed": "true",
"max_connection_lifetime": "1",
}, logger)
if err != nil {
t.Fatalf("Failed to create new backend: %v", err)
}
defer func() {
mysql := b.(*MySQLBackend)
_, err := mysql.client.Exec("DROP TABLE IF EXISTS " + mysql.dbTable + " ," + mysql.dbLockTable)
if err != nil {
t.Fatalf("Failed to drop table: %v", err)
}
}()
physical.ExerciseBackend(t, b)
physical.ExerciseBackend_ListPrefix(t, b)
}
func TestMySQLHABackend(t *testing.T) {
address := os.Getenv("MYSQL_ADDR")
if address == "" {
t.SkipNow()
}
database := os.Getenv("MYSQL_DB")
if database == "" {
database = "test"
}
table := os.Getenv("MYSQL_TABLE")
if table == "" {
table = "test"
}
username := os.Getenv("MYSQL_USERNAME")
password := os.Getenv("MYSQL_PASSWORD")
// Run vault tests
logger := logging.NewVaultLogger(log.Debug)
config := map[string]string{
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"ha_enabled": "true",
"plaintext_connection_allowed": "true",
}
b, err := NewMySQLBackend(config, logger)
if err != nil {
t.Fatalf("Failed to create new backend: %v", err)
}
defer func() {
mysql := b.(*MySQLBackend)
_, err := mysql.client.Exec("DROP TABLE IF EXISTS " + mysql.dbTable + " ," + mysql.dbLockTable)
if err != nil {
t.Fatalf("Failed to drop table: %v", err)
}
}()
b2, err := NewMySQLBackend(config, logger)
if err != nil {
t.Fatalf("Failed to create new backend: %v", err)
}
physical.ExerciseHABackend(t, b.(physical.HABackend), b2.(physical.HABackend))
}
// TestMySQLHABackend_LockFailPanic is a regression test for the panic shown in
// https://github.com/hashicorp/vault/issues/8203 and patched in
// https://github.com/hashicorp/vault/pull/8229
func TestMySQLHABackend_LockFailPanic(t *testing.T) {
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
cfg, err := mysql.ParseDSN(connURL)
if err != nil {
t.Fatal(err)
}
if err := mysqlhelper.TestCredsExist(t, connURL, cfg.User, cfg.Passwd); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
table := "test"
logger := logging.NewVaultLogger(log.Debug)
config := map[string]string{
"address": cfg.Addr,
"database": cfg.DBName,
"table": table,
"username": cfg.User,
"password": cfg.Passwd,
"ha_enabled": "true",
"plaintext_connection_allowed": "true",
}
b, err := NewMySQLBackend(config, logger)
if err != nil {
t.Fatalf("Failed to create new backend: %v", err)
}
b2, err := NewMySQLBackend(config, logger)
if err != nil {
t.Fatalf("Failed to create new backend: %v", err)
}
b1ha := b.(physical.HABackend)
b2ha := b2.(physical.HABackend)
// Copied from ExerciseHABackend - ensuring things are normal at this point
// Get the lock
lock, err := b1ha.LockWith("foo", "bar")
if err != nil {
t.Fatalf("initial lock: %v", err)
}
// Attempt to lock
leaderCh, err := lock.Lock(nil)
if err != nil {
t.Fatalf("lock attempt 1: %v", err)
}
if leaderCh == nil {
t.Fatalf("missing leaderCh")
}
// Check the value
held, val, err := lock.Value()
if err != nil {
t.Fatalf("err: %v", err)
}
if !held {
t.Errorf("should be held")
}
if val != "bar" {
t.Errorf("expected value bar: %v", err)
}
// Second acquisition should fail
lock2, err := b2ha.LockWith("foo", "baz")
if err != nil {
t.Fatalf("lock 2: %v", err)
}
stopCh := make(chan struct{})
time.AfterFunc(10*time.Second, func() {
close(stopCh)
})
// Attempt to lock - can't lock because lock1 is held - this is normal
leaderCh2, err := lock2.Lock(stopCh)
if err != nil {
t.Fatalf("stop lock 2: %v", err)
}
if leaderCh2 != nil {
t.Errorf("should not have gotten leaderCh: %v", leaderCh2)
}
// end normal
// Clean up the database. When Lock() is called, a new connection is created
// using the configuration. If that connection cannot be created, the old code
// panicked because it did not return on the connection error. Here we
// intentionally break the config for b2 so a new connection can't be made,
// which would trigger the panic shown in https://github.com/hashicorp/vault/issues/8203
cleanup()
stopCh2 := make(chan struct{})
time.AfterFunc(10*time.Second, func() {
close(stopCh2)
})
leaderCh2, err = lock2.Lock(stopCh2)
if err == nil {
t.Fatalf("expected error, got none, leaderCh2=%v", leaderCh2)
}
}
func TestValidateDBTable(t *testing.T) {
type testCase struct {
database string
table string
expectErr bool
}
tests := map[string]testCase{
"empty database & table": {"", "", true},
"empty database": {"", "a", true},
"empty table": {"a", "", true},
"ascii database": {"abcde", "a", false},
"ascii table": {"a", "abcde", false},
"ascii database & table": {"abcde", "abcde", false},
"only whitespace db": {" ", "a", true},
"only whitespace table": {"a", " ", true},
"whitespace prefix db": {" bcde", "a", true},
"whitespace middle db": {"ab de", "a", true},
"whitespace suffix db": {"abcd ", "a", true},
"whitespace prefix table": {"a", " bcde", true},
"whitespace middle table": {"a", "ab de", true},
"whitespace suffix table": {"a", "abcd ", true},
"backtick prefix db": {"`bcde", "a", true},
"backtick middle db": {"ab`de", "a", true},
"backtick suffix db": {"abcd`", "a", true},
"backtick prefix table": {"a", "`bcde", true},
"backtick middle table": {"a", "ab`de", true},
"backtick suffix table": {"a", "abcd`", true},
"single quote prefix db": {"'bcde", "a", true},
"single quote middle db": {"ab'de", "a", true},
"single quote suffix db": {"abcd'", "a", true},
"single quote prefix table": {"a", "'bcde", true},
"single quote middle table": {"a", "ab'de", true},
"single quote suffix table": {"a", "abcd'", true},
"double quote prefix db": {`"bcde`, "a", true},
"double quote middle db": {`ab"de`, "a", true},
"double quote suffix db": {`abcd"`, "a", true},
"double quote prefix table": {"a", `"bcde`, true},
"double quote middle table": {"a", `ab"de`, true},
"double quote suffix table": {"a", `abcd"`, true},
"0x0000 prefix db": {str(0x0000, 'b', 'c'), "a", true},
"0x0000 middle db": {str('a', 0x0000, 'c'), "a", true},
"0x0000 suffix db": {str('a', 'b', 0x0000), "a", true},
"0x0000 prefix table": {"a", str(0x0000, 'b', 'c'), true},
"0x0000 middle table": {"a", str('a', 0x0000, 'c'), true},
"0x0000 suffix table": {"a", str('a', 'b', 0x0000), true},
"unicode > 0xFFFF prefix db": {str(0x10000, 'b', 'c'), "a", true},
"unicode > 0xFFFF middle db": {str('a', 0x10000, 'c'), "a", true},
"unicode > 0xFFFF suffix db": {str('a', 'b', 0x10000), "a", true},
"unicode > 0xFFFF prefix table": {"a", str(0x10000, 'b', 'c'), true},
"unicode > 0xFFFF middle table": {"a", str('a', 0x10000, 'c'), true},
"unicode > 0xFFFF suffix table": {"a", str('a', 'b', 0x10000), true},
"non-printable prefix db": {str(0x0001, 'b', 'c'), "a", true},
"non-printable middle db": {str('a', 0x0001, 'c'), "a", true},
"non-printable suffix db": {str('a', 'b', 0x0001), "a", true},
"non-printable prefix table": {"a", str(0x0001, 'b', 'c'), true},
"non-printable middle table": {"a", str('a', 0x0001, 'c'), true},
"non-printable suffix table": {"a", str('a', 'b', 0x0001), true},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
err := validateDBTable(test.database, test.table)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}
func str(r ...rune) string {
return string(r)
}

View File

@ -1,474 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package postgresql
import (
"context"
"database/sql"
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/armon/go-metrics"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/physical"
_ "github.com/jackc/pgx/v4/stdlib"
)
const (
// The lock TTL matches the default that Consul API uses, 15 seconds.
// Used as part of SQL commands to set/extend lock expiry time relative to
// database clock.
PostgreSQLLockTTLSeconds = 15
// The amount of time to wait between the lock renewals
PostgreSQLLockRenewInterval = 5 * time.Second
// PostgreSQLLockRetryInterval is the amount of time to wait
// if a lock fails before trying again.
PostgreSQLLockRetryInterval = time.Second
)
// Verify PostgreSQLBackend satisfies the correct interfaces
var _ physical.Backend = (*PostgreSQLBackend)(nil)
// The HA backend is modeled on the DynamoDB backend pattern, with the
// distinction that it uses the central Postgres clock, thereby avoiding
// possible issues with multiple clocks.
var (
_ physical.HABackend = (*PostgreSQLBackend)(nil)
_ physical.Lock = (*PostgreSQLLock)(nil)
)
// PostgreSQLBackend is a physical backend that stores data
// within a PostgreSQL database.
type PostgreSQLBackend struct {
table string
client *sql.DB
put_query string
get_query string
delete_query string
list_query string
ha_table string
haGetLockValueQuery string
haUpsertLockIdentityExec string
haDeleteLockExec string
haEnabled bool
logger log.Logger
permitPool *physical.PermitPool
}
// PostgreSQLLock implements a lock using a PostgreSQL client.
type PostgreSQLLock struct {
backend *PostgreSQLBackend
value, key string
identity string
lock sync.Mutex
renewTicker *time.Ticker
// ttlSeconds is how long a lock is valid for
ttlSeconds int
// renewInterval is how much time to wait between lock renewals. must be << ttl
renewInterval time.Duration
// retryInterval is how much time to wait between attempts to grab the lock
retryInterval time.Duration
}
// NewPostgreSQLBackend constructs a PostgreSQL backend from the given
// configuration (connection URL, table names, tuning parameters) and logger.
func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
// Get the PostgreSQL credentials to perform read/write operations.
connURL := connectionURL(conf)
if connURL == "" {
return nil, fmt.Errorf("missing connection_url")
}
unquoted_table, ok := conf["table"]
if !ok {
unquoted_table = "vault_kv_store"
}
quoted_table := dbutil.QuoteIdentifier(unquoted_table)
maxParStr, ok := conf["max_parallel"]
var maxParInt int
var err error
if ok {
maxParInt, err = strconv.Atoi(maxParStr)
if err != nil {
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
}
if logger.IsDebug() {
logger.Debug("max_parallel set", "max_parallel", maxParInt)
}
} else {
maxParInt = physical.DefaultParallelOperations
}
maxIdleConnsStr, maxIdleConnsIsSet := conf["max_idle_connections"]
var maxIdleConns int
if maxIdleConnsIsSet {
maxIdleConns, err = strconv.Atoi(maxIdleConnsStr)
if err != nil {
return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err)
}
if logger.IsDebug() {
logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnsStr)
}
}
// Create PostgreSQL handle for the database.
db, err := sql.Open("pgx", connURL)
if err != nil {
return nil, fmt.Errorf("failed to connect to postgres: %w", err)
}
db.SetMaxOpenConns(maxParInt)
if maxIdleConnsIsSet {
db.SetMaxIdleConns(maxIdleConns)
}
// Determine if we should use a function to work around lack of upsert (versions < 9.5)
var upsertAvailable bool
upsertAvailableQuery := "SELECT current_setting('server_version_num')::int >= 90500"
if err := db.QueryRow(upsertAvailableQuery).Scan(&upsertAvailable); err != nil {
return nil, fmt.Errorf("failed to check for native upsert: %w", err)
}
if !upsertAvailable && conf["ha_enabled"] == "true" {
return nil, fmt.Errorf("ha_enabled=true in config but PG version doesn't support HA, must be at least 9.5")
}
// Setup our put strategy based on the presence or absence of a native
// upsert.
var put_query string
if !upsertAvailable {
put_query = "SELECT vault_kv_put($1, $2, $3, $4)"
} else {
put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
" ON CONFLICT (path, key) DO " +
" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
}
unquoted_ha_table, ok := conf["ha_table"]
if !ok {
unquoted_ha_table = "vault_ha_locks"
}
quoted_ha_table := dbutil.QuoteIdentifier(unquoted_ha_table)
// Setup the backend.
m := &PostgreSQLBackend{
table: quoted_table,
client: db,
put_query: put_query,
get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
" UNION ALL SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + quoted_table +
" WHERE parent_path LIKE $1 || '%'",
haGetLockValueQuery:
// only read non expired data
" SELECT ha_value FROM " + quoted_ha_table + " WHERE NOW() <= valid_until AND ha_key = $1 ",
haUpsertLockIdentityExec:
// $1=identity $2=ha_key $3=ha_value $4=TTL in seconds
// either steal an expired lock OR extend the expiry of a lock we already own
" INSERT INTO " + quoted_ha_table + " as t (ha_identity, ha_key, ha_value, valid_until) VALUES ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds' ) " +
" ON CONFLICT (ha_key) DO " +
" UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds') " +
" WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " +
" (t.ha_identity = $1 AND t.ha_key = $2) ",
haDeleteLockExec:
// $1=ha_identity $2=ha_key
" DELETE FROM " + quoted_ha_table + " WHERE ha_identity=$1 AND ha_key=$2 ",
logger: logger,
permitPool: physical.NewPermitPool(maxParInt),
haEnabled: conf["ha_enabled"] == "true",
}
return m, nil
}
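// Configuration sketch (values are illustrative, not taken from this repo): the
// conf map handled above recognizes connection_url, table, ha_table, ha_enabled,
// max_parallel and max_idle_connections, so a typical construction looks like:
//
//	b, err := NewPostgreSQLBackend(map[string]string{
//		"connection_url":       "postgres://vault:secret@127.0.0.1:5432/vault?sslmode=disable",
//		"table":                "vault_kv_store",
//		"ha_enabled":           "true",
//		"max_parallel":         "128",
//		"max_idle_connections": "4",
//	}, logger)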
// connectionURL first checks the environment for a connection URL. If no
// connection URL is set in the environment, the Vault config file is checked.
// Since connection_url is a required field for the Postgres backend, the caller
// returns an error when neither the environment nor the config file provides it.
func connectionURL(conf map[string]string) string {
connURL := conf["connection_url"]
if envURL := os.Getenv("VAULT_PG_CONNECTION_URL"); envURL != "" {
connURL = envURL
}
return connURL
}
// splitKey is a helper to split a full path key into individual
// parts: parentPath, path, key
func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
var parentPath string
var path string
pieces := strings.Split(fullPath, "/")
depth := len(pieces)
key := pieces[depth-1]
if depth == 1 {
parentPath = ""
path = "/"
} else if depth == 2 {
parentPath = "/"
path = "/" + pieces[0] + "/"
} else {
parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/"
path = "/" + strings.Join(pieces[:depth-1], "/") + "/"
}
return parentPath, path, key
}
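// Worked examples (derived from the cases above, shown for illustration):
//
//	splitKey("foo")         -> parentPath: "",      path: "/",         key: "foo"
//	splitKey("foo/bar")     -> parentPath: "/",     path: "/foo/",     key: "bar"
//	splitKey("foo/bar/baz") -> parentPath: "/foo/", path: "/foo/bar/", key: "baz"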
// Put is used to insert or update an entry.
func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())
m.permitPool.Acquire()
defer m.permitPool.Release()
parentPath, path, key := m.splitKey(entry.Key)
_, err := m.client.ExecContext(ctx, m.put_query, parentPath, path, key, entry.Value)
if err != nil {
return err
}
return nil
}
// Get is used to fetch an entry.
func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) {
defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())
m.permitPool.Acquire()
defer m.permitPool.Release()
_, path, key := m.splitKey(fullPath)
var result []byte
err := m.client.QueryRowContext(ctx, m.get_query, path, key).Scan(&result)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
ent := &physical.Entry{
Key: fullPath,
Value: result,
}
return ent, nil
}
// Delete is used to permanently delete an entry
func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error {
defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())
m.permitPool.Acquire()
defer m.permitPool.Release()
_, path, key := m.splitKey(fullPath)
_, err := m.client.ExecContext(ctx, m.delete_query, path, key)
if err != nil {
return err
}
return nil
}
// List is used to list all the keys under a given
// prefix, up to the next prefix.
func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())
m.permitPool.Acquire()
defer m.permitPool.Release()
rows, err := m.client.QueryContext(ctx, m.list_query, "/"+prefix)
if err != nil {
return nil, err
}
defer rows.Close()
var keys []string
for rows.Next() {
var key string
err = rows.Scan(&key)
if err != nil {
return nil, fmt.Errorf("failed to scan rows: %w", err)
}
keys = append(keys, key)
}
return keys, nil
}
// LockWith is used for mutual exclusion based on the given key.
func (p *PostgreSQLBackend) LockWith(key, value string) (physical.Lock, error) {
identity, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
return &PostgreSQLLock{
backend: p,
key: key,
value: value,
identity: identity,
ttlSeconds: PostgreSQLLockTTLSeconds,
renewInterval: PostgreSQLLockRenewInterval,
retryInterval: PostgreSQLLockRetryInterval,
}, nil
}
func (p *PostgreSQLBackend) HAEnabled() bool {
return p.haEnabled
}
// Lock tries to acquire the lock by repeatedly trying to create a record in the
// PostgreSQL table. It will block until either the stop channel is closed or
// the lock could be acquired successfully. The returned channel will be closed
// once the lock in the PostgreSQL table cannot be renewed, either due to an
// error speaking to PostgreSQL or because someone else has taken it.
func (l *PostgreSQLLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
l.lock.Lock()
defer l.lock.Unlock()
var (
success = make(chan struct{})
errors = make(chan error)
leader = make(chan struct{})
)
// try to acquire the lock asynchronously
go l.tryToLock(stopCh, success, errors)
select {
case <-success:
// after acquiring it successfully, we must renew the lock periodically
l.renewTicker = time.NewTicker(l.renewInterval)
go l.periodicallyRenewLock(leader)
case err := <-errors:
return nil, err
case <-stopCh:
return nil, nil
}
return leader, nil
}
// Unlock releases the lock by deleting the lock record from the
// PostgreSQL table.
func (l *PostgreSQLLock) Unlock() error {
pg := l.backend
pg.permitPool.Acquire()
defer pg.permitPool.Release()
if l.renewTicker != nil {
l.renewTicker.Stop()
}
// Delete lock owned by me
_, err := pg.client.Exec(pg.haDeleteLockExec, l.identity, l.key)
return err
}
// Value checks whether or not the lock is held by any instance of PostgreSQLLock,
// including this one, and returns the current value.
func (l *PostgreSQLLock) Value() (bool, string, error) {
pg := l.backend
pg.permitPool.Acquire()
defer pg.permitPool.Release()
var result string
err := pg.client.QueryRow(pg.haGetLockValueQuery, l.key).Scan(&result)
switch err {
case nil:
return true, result, nil
case sql.ErrNoRows:
return false, "", nil
default:
return false, "", err
}
}
// tryToLock tries to create a new item in PostgreSQL every retryInterval.
// As long as the item cannot be created (because it already exists), the attempt
// is retried. If the operation fails with an error, the error is sent to the
// errors channel. Once the lock is acquired successfully, the success channel
// is closed.
func (l *PostgreSQLLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) {
ticker := time.NewTicker(l.retryInterval)
defer ticker.Stop()
for {
select {
case <-stop:
return
case <-ticker.C:
gotlock, err := l.writeItem()
switch {
case err != nil:
errors <- err
return
case gotlock:
close(success)
return
}
}
}
}
func (l *PostgreSQLLock) periodicallyRenewLock(done chan struct{}) {
for range l.renewTicker.C {
gotlock, err := l.writeItem()
if err != nil || !gotlock {
close(done)
l.renewTicker.Stop()
return
}
}
}
// writeItem attempts to insert or update the PostgreSQL lock row, using the
// valid_until condition to evaluate the TTL. It returns true if the lock was
// obtained and false if not. When it returns false, the error may be nil or
// non-nil: nil simply means someone else holds the lock, whereas non-nil means
// something unexpected happened.
func (l *PostgreSQLLock) writeItem() (bool, error) {
pg := l.backend
pg.permitPool.Acquire()
defer pg.permitPool.Release()
// Try steal lock or update expiry on my lock
sqlResult, err := pg.client.Exec(pg.haUpsertLockIdentityExec, l.identity, l.key, l.value, l.ttlSeconds)
if err != nil {
return false, err
}
if sqlResult == nil {
return false, fmt.Errorf("empty SQL response received")
}
ar, err := sqlResult.RowsAffected()
if err != nil {
return false, err
}
return ar == 1, nil
}
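// For illustration (reasoning only, derived from haUpsertLockIdentityExec above):
// the ON CONFLICT ... DO UPDATE ... WHERE clause only applies when the existing
// row is expired (t.valid_until < NOW()) or already owned by this identity, so
// RowsAffected() is 1 when we inserted a new lock row, stole an expired lock,
// or renewed our own lock, and 0 when another live holder still has it.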

View File

@ -1,426 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package postgresql
import (
"fmt"
"os"
"testing"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/physical"
_ "github.com/jackc/pgx/v4/stdlib"
)
func TestPostgreSQLBackend(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)
// Use docker as pg backend if no url is provided via environment variables
connURL := os.Getenv("PGURL")
if connURL == "" {
cleanup, u := postgresql.PrepareTestContainer(t, "11.1")
defer cleanup()
connURL = u
}
table := os.Getenv("PGTABLE")
if table == "" {
table = "vault_kv_store"
}
hae := os.Getenv("PGHAENABLED")
if hae == "" {
hae = "true"
}
// Run vault tests
logger.Info(fmt.Sprintf("Connection URL: %v", connURL))
b1, err := NewPostgreSQLBackend(map[string]string{
"connection_url": connURL,
"table": table,
"ha_enabled": hae,
}, logger)
if err != nil {
t.Fatalf("Failed to create new backend: %v", err)
}
b2, err := NewPostgreSQLBackend(map[string]string{
"connection_url": connURL,
"table": table,
"ha_enabled": hae,
}, logger)
if err != nil {
t.Fatalf("Failed to create new backend: %v", err)
}
pg := b1.(*PostgreSQLBackend)
// Read postgres version to test basic connects works
var pgversion string
if err = pg.client.QueryRow("SELECT current_setting('server_version_num')").Scan(&pgversion); err != nil {
t.Fatalf("Failed to check for Postgres version: %v", err)
}
logger.Info(fmt.Sprintf("Postgres Version: %v", pgversion))
setupDatabaseObjects(t, logger, pg)
defer func() {
pg := b1.(*PostgreSQLBackend)
_, err := pg.client.Exec(fmt.Sprintf(" TRUNCATE TABLE %v ", pg.table))
if err != nil {
t.Fatalf("Failed to truncate table: %v", err)
}
}()
logger.Info("Running basic backend tests")
physical.ExerciseBackend(t, b1)
logger.Info("Running list prefix backend tests")
physical.ExerciseBackend_ListPrefix(t, b1)
ha1, ok := b1.(physical.HABackend)
if !ok {
t.Fatalf("PostgreSQLDB does not implement HABackend")
}
ha2, ok := b2.(physical.HABackend)
if !ok {
t.Fatalf("PostgreSQLDB does not implement HABackend")
}
if ha1.HAEnabled() && ha2.HAEnabled() {
logger.Info("Running ha backend tests")
physical.ExerciseHABackend(t, ha1, ha2)
testPostgresSQLLockTTL(t, ha1)
testPostgresSQLLockRenewal(t, ha1)
}
}
func TestPostgreSQLBackendMaxIdleConnectionsParameter(t *testing.T) {
_, err := NewPostgreSQLBackend(map[string]string{
"connection_url": "some connection url",
"max_idle_connections": "bad param",
}, logging.NewVaultLogger(log.Debug))
if err == nil {
t.Error("Expected invalid max_idle_connections param to return error")
}
expectedErrStr := "failed parsing max_idle_connections parameter: strconv.Atoi: parsing \"bad param\": invalid syntax"
if err.Error() != expectedErrStr {
t.Errorf("Expected: %q but found %q", expectedErrStr, err.Error())
}
}
func TestConnectionURL(t *testing.T) {
type input struct {
envar string
conf map[string]string
}
cases := map[string]struct {
want string
input input
}{
"environment_variable_not_set_use_config_value": {
want: "abc",
input: input{
envar: "",
conf: map[string]string{"connection_url": "abc"},
},
},
"no_value_connection_url_set_key_exists": {
want: "",
input: input{
envar: "",
conf: map[string]string{"connection_url": ""},
},
},
"no_value_connection_url_set_key_doesnt_exist": {
want: "",
input: input{
envar: "",
conf: map[string]string{},
},
},
"environment_variable_set": {
want: "abc",
input: input{
envar: "abc",
conf: map[string]string{"connection_url": "def"},
},
},
}
for name, tt := range cases {
t.Run(name, func(t *testing.T) {
// This is necessary to avoid always testing the branch where the env is set.
// As long as the env is set --- even if the value is "" --- `ok` returns true.
if tt.input.envar != "" {
os.Setenv("VAULT_PG_CONNECTION_URL", tt.input.envar)
defer os.Unsetenv("VAULT_PG_CONNECTION_URL")
}
got := connectionURL(tt.input.conf)
if got != tt.want {
t.Errorf("connectionURL(%s): want %q, got %q", tt.input, tt.want, got)
}
})
}
}
// Similar to testHABackend, but using internal implementation details to
// trigger the lock failure scenario by setting the lock renew period for one
// of the locks to a higher value than the lock TTL.
const maxTries = 3
func testPostgresSQLLockTTL(t *testing.T, ha physical.HABackend) {
t.Log("Skipping testPostgresSQLLockTTL portion of test.")
return
for tries := 1; tries <= maxTries; tries++ {
// Try this several times. If the test environment is too slow the lock can naturally lapse
if attemptLockTTLTest(t, ha, tries) {
break
}
}
}
func attemptLockTTLTest(t *testing.T, ha physical.HABackend, tries int) bool {
// Set much smaller lock times to speed up the test.
lockTTL := 3
renewInterval := time.Second * 1
retryInterval := time.Second * 1
longRenewInterval := time.Duration(lockTTL*2) * time.Second
lockkey := "postgresttl"
var leaderCh <-chan struct{}
// Get the lock
origLock, err := ha.LockWith(lockkey, "bar")
if err != nil {
t.Fatalf("err: %v", err)
}
{
// set the first lock renew period to double the expected TTL.
lock := origLock.(*PostgreSQLLock)
lock.renewInterval = longRenewInterval
lock.ttlSeconds = lockTTL
// Attempt to lock
lockTime := time.Now()
leaderCh, err = lock.Lock(nil)
if err != nil {
t.Fatalf("err: %v", err)
}
if leaderCh == nil {
t.Fatalf("failed to get leader ch")
}
if tries == 1 {
time.Sleep(3 * time.Second)
}
// Check the value
held, val, err := lock.Value()
if err != nil {
t.Fatalf("err: %v", err)
}
if !held {
if tries < maxTries && time.Since(lockTime) > (time.Second*time.Duration(lockTTL)) {
// Our test environment is slow enough that we failed this, retry
return false
}
t.Fatalf("should be held")
}
if val != "bar" {
t.Fatalf("bad value: %v", val)
}
}
// Second acquisition should succeed because the first lock should
// not renew within the 3 sec TTL.
origLock2, err := ha.LockWith(lockkey, "baz")
if err != nil {
t.Fatalf("err: %v", err)
}
{
lock2 := origLock2.(*PostgreSQLLock)
lock2.renewInterval = renewInterval
lock2.ttlSeconds = lockTTL
lock2.retryInterval = retryInterval
// Cancel attempt in 6 sec so as not to block unit tests forever
stopCh := make(chan struct{})
time.AfterFunc(time.Duration(lockTTL*2)*time.Second, func() {
close(stopCh)
})
// Attempt to lock should work
lockTime := time.Now()
leaderCh2, err := lock2.Lock(stopCh)
if err != nil {
t.Fatalf("err: %v", err)
}
if leaderCh2 == nil {
t.Fatalf("should get leader ch")
}
defer lock2.Unlock()
// Check the value
held, val, err := lock2.Value()
if err != nil {
t.Fatalf("err: %v", err)
}
if !held {
if tries < maxTries && time.Since(lockTime) > (time.Second*time.Duration(lockTTL)) {
// Our test environment is slow enough that we failed this, retry
return false
}
t.Fatalf("should be held")
}
if val != "baz" {
t.Fatalf("bad value: %v", val)
}
}
// The first lock should have lost the leader channel
select {
case <-time.After(longRenewInterval * 2):
t.Fatalf("original lock did not have its leader channel closed.")
case <-leaderCh:
}
return true
}
// Verify that once Unlock is called, we don't keep trying to renew the original
// lock.
func testPostgresSQLLockRenewal(t *testing.T, ha physical.HABackend) {
// Get the lock
origLock, err := ha.LockWith("pgrenewal", "bar")
if err != nil {
t.Fatalf("err: %v", err)
}
// customize the renewal and watch intervals
lock := origLock.(*PostgreSQLLock)
// lock.renewInterval = time.Second * 1
// Attempt to lock
leaderCh, err := lock.Lock(nil)
if err != nil {
t.Fatalf("err: %v", err)
}
if leaderCh == nil {
t.Fatalf("failed to get leader ch")
}
// Check the value
held, val, err := lock.Value()
if err != nil {
t.Fatalf("err: %v", err)
}
if !held {
t.Fatalf("should be held")
}
if val != "bar" {
t.Fatalf("bad value: %v", val)
}
// Release the lock, which will delete the stored item
if err := lock.Unlock(); err != nil {
t.Fatalf("err: %v", err)
}
// Wait longer than the renewal time
time.Sleep(1500 * time.Millisecond)
// Attempt to lock with new lock
newLock, err := ha.LockWith("pgrenewal", "baz")
if err != nil {
t.Fatalf("err: %v", err)
}
stopCh := make(chan struct{})
timeout := time.Duration(lock.ttlSeconds)*time.Second + lock.retryInterval + time.Second
var leaderCh2 <-chan struct{}
newlockch := make(chan struct{})
go func() {
leaderCh2, err = newLock.Lock(stopCh)
close(newlockch)
}()
// Cancel attempt after lock ttl + 1s so as not to block unit tests forever
select {
case <-time.After(timeout):
t.Logf("giving up on lock attempt after %v", timeout)
close(stopCh)
case <-newlockch:
// pass through
}
// Attempt to lock should work
if err != nil {
t.Fatalf("err: %v", err)
}
if leaderCh2 == nil {
t.Fatalf("should get leader ch")
}
// Check the value
held, val, err = newLock.Value()
if err != nil {
t.Fatalf("err: %v", err)
}
if !held {
t.Fatalf("should be held")
}
if val != "baz" {
t.Fatalf("bad value: %v", val)
}
// Cleanup
newLock.Unlock()
}
func setupDatabaseObjects(t *testing.T, logger log.Logger, pg *PostgreSQLBackend) {
var err error
// Set up the tables and indexes if they do not already exist.
createTableSQL := fmt.Sprintf(
" CREATE TABLE IF NOT EXISTS %v ( "+
" parent_path TEXT COLLATE \"C\" NOT NULL, "+
" path TEXT COLLATE \"C\", "+
" key TEXT COLLATE \"C\", "+
" value BYTEA, "+
" CONSTRAINT pkey PRIMARY KEY (path, key) "+
" ); ", pg.table)
_, err = pg.client.Exec(createTableSQL)
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
createIndexSQL := fmt.Sprintf(" CREATE INDEX IF NOT EXISTS parent_path_idx ON %v (parent_path); ", pg.table)
_, err = pg.client.Exec(createIndexSQL)
if err != nil {
t.Fatalf("Failed to create index: %v", err)
}
createHaTableSQL := " CREATE TABLE IF NOT EXISTS vault_ha_locks ( " +
" ha_key TEXT COLLATE \"C\" NOT NULL, " +
" ha_identity TEXT COLLATE \"C\" NOT NULL, " +
" ha_value TEXT COLLATE \"C\", " +
" valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " +
" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
" ); "
_, err = pg.client.Exec(createHaTableSQL)
if err != nil {
t.Fatalf("Failed to create hatable: %v", err)
}
}