Compare commits
3 Commits
f947ce9080
...
bea345a84c
Author | SHA1 | Date | |
---|---|---|---|
bea345a84c | |||
3fcad1ec13 | |||
13230a9749 |
|
@ -41,12 +41,9 @@ import (
|
|||
logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
|
||||
logicalDb "github.com/hashicorp/vault/builtin/logical/database"
|
||||
|
||||
physCockroachDB "github.com/hashicorp/vault/physical/cockroachdb"
|
||||
physConsul "github.com/hashicorp/vault/physical/consul"
|
||||
physFoundationDB "github.com/hashicorp/vault/physical/foundationdb"
|
||||
physMySQL "github.com/hashicorp/vault/physical/mysql"
|
||||
physOCI "github.com/hashicorp/vault/physical/oci"
|
||||
physPostgreSQL "github.com/hashicorp/vault/physical/postgresql"
|
||||
physRaft "github.com/hashicorp/vault/physical/raft"
|
||||
physFile "github.com/hashicorp/vault/sdk/physical/file"
|
||||
physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
|
||||
|
@ -168,7 +165,6 @@ var (
|
|||
}
|
||||
|
||||
physicalBackends = map[string]physical.Factory{
|
||||
"cockroachdb": physCockroachDB.NewCockroachDBBackend,
|
||||
"consul": physConsul.NewConsulBackend,
|
||||
"file_transactional": physFile.NewTransactionalFileBackend,
|
||||
"file": physFile.NewFileBackend,
|
||||
|
@ -177,9 +173,7 @@ var (
|
|||
"inmem_transactional_ha": physInmem.NewTransactionalInmemHA,
|
||||
"inmem_transactional": physInmem.NewTransactionalInmem,
|
||||
"inmem": physInmem.NewInmem,
|
||||
"mysql": physMySQL.NewMySQLBackend,
|
||||
"oci": physOCI.NewBackend,
|
||||
"postgresql": physPostgreSQL.NewPostgreSQLBackend,
|
||||
"raft": physRaft.NewRaftBackend,
|
||||
}
|
||||
|
||||
|
|
2
go.mod
2
go.mod
|
@ -33,7 +33,6 @@ require (
|
|||
github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a
|
||||
github.com/cenkalti/backoff/v3 v3.2.2
|
||||
github.com/chrismalek/oktasdk-go v0.0.0-20181212195951-3430665dfaa0
|
||||
github.com/cockroachdb/cockroach-go v0.0.0-20181001143604-e0a95dfd547c
|
||||
github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf
|
||||
github.com/duosecurity/duo_api_golang v0.0.0-20190308151101-6c680f768e74
|
||||
github.com/dustin/go-humanize v1.0.1
|
||||
|
@ -295,7 +294,6 @@ require (
|
|||
github.com/jackc/pgproto3/v2 v2.3.3 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgtype v1.14.0 // indirect
|
||||
github.com/jackc/pgx v3.3.0+incompatible // indirect
|
||||
github.com/jcmturner/aescts/v2 v2.0.0 // indirect
|
||||
github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect
|
||||
github.com/jcmturner/gofork v1.7.6 // indirect
|
||||
|
|
5
go.sum
5
go.sum
|
@ -1095,8 +1095,6 @@ github.com/cncf/xds/go v0.0.0-20230428030218-4003588d1b74/go.mod h1:eXthEFrGJvWH
|
|||
github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
|
||||
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
|
||||
github.com/cockroachdb/cockroach-go v0.0.0-20181001143604-e0a95dfd547c h1:2zRrJWIt/f9c9HhNHAgrRgq0San5gRRUJTBXLkchal0=
|
||||
github.com/cockroachdb/cockroach-go v0.0.0-20181001143604-e0a95dfd547c/go.mod h1:XGLbWH/ujMcbPbhZq52Nv6UrCghb1yGn//133kEsvDk=
|
||||
github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8=
|
||||
github.com/cockroachdb/datadriven v0.0.0-20200714090401-bf6692d28da5/go.mod h1:h6jFvWxBdQXxjopDMZyH2UVceIRfR84bdzbkoKrsWNo=
|
||||
github.com/cockroachdb/errors v1.2.4/go.mod h1:rQD95gz6FARkaKkQXUksEje/d9a6wBJoCr5oaCLELYA=
|
||||
|
@ -2092,7 +2090,6 @@ github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9
|
|||
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
|
||||
github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 h1:vr3AYkKovP8uR8AvSGGUK1IDqRa5lAAvEkZG1LKaCRc=
|
||||
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733/go.mod h1:WrMFNQdiFJ80sQsxDoMokWK1W5TQtxBFNpzWTD84ibQ=
|
||||
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
|
||||
github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
|
||||
|
@ -2128,7 +2125,6 @@ github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrU
|
|||
github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM=
|
||||
github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw=
|
||||
github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4=
|
||||
github.com/jackc/pgx v3.3.0+incompatible h1:Wa90/+qsITBAPkAZjiByeIGHFcj3Ztu+VzrrIpHjL90=
|
||||
github.com/jackc/pgx v3.3.0+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I=
|
||||
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
|
||||
github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
|
||||
|
@ -2683,7 +2679,6 @@ github.com/safchain/ethtool v0.0.0-20210803160452-9aa261dae9b1/go.mod h1:Z0q5wiB
|
|||
github.com/safchain/ethtool v0.2.0/go.mod h1:WkKB1DnNtvsMlDmQ50sgwowDJV/hGbJSOvJoEXs1AJQ=
|
||||
github.com/sasha-s/go-deadlock v0.2.0 h1:lMqc+fUb7RrFS3gQLtoQsJ7/6TV/pAIFvBsqX73DK8Y=
|
||||
github.com/sasha-s/go-deadlock v0.2.0/go.mod h1:StQn567HiB1fF2yJ44N9au7wOhrPS3iZqiDbRupzT10=
|
||||
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
|
||||
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
|
||||
github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
|
||||
github.com/sclevine/spec v1.2.0/go.mod h1:W4J29eT/Kzv7/b9IWLB055Z+qvVC9vt0Arko24q7p+U=
|
||||
|
|
|
@ -1,356 +0,0 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package cockroachdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
"github.com/cockroachdb/cockroach-go/crdb"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
|
||||
// CockroachDB uses the Postgres SQL driver
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
)
|
||||
|
||||
// Verify CockroachDBBackend satisfies the correct interfaces
|
||||
var (
|
||||
_ physical.Backend = (*CockroachDBBackend)(nil)
|
||||
_ physical.Transactional = (*CockroachDBBackend)(nil)
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTableName = "vault_kv_store"
|
||||
defaultHATableName = "vault_ha_locks"
|
||||
)
|
||||
|
||||
// CockroachDBBackend Backend is a physical backend that stores data
|
||||
// within a CockroachDB database.
|
||||
type CockroachDBBackend struct {
|
||||
table string
|
||||
haTable string
|
||||
client *sql.DB
|
||||
rawStatements map[string]string
|
||||
statements map[string]*sql.Stmt
|
||||
rawHAStatements map[string]string
|
||||
haStatements map[string]*sql.Stmt
|
||||
logger log.Logger
|
||||
permitPool *physical.PermitPool
|
||||
haEnabled bool
|
||||
}
|
||||
|
||||
// NewCockroachDBBackend constructs a CockroachDB backend using the given
|
||||
// API client, server address, credentials, and database.
|
||||
func NewCockroachDBBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
||||
// Get the CockroachDB credentials to perform read/write operations.
|
||||
connURL, ok := conf["connection_url"]
|
||||
if !ok || connURL == "" {
|
||||
return nil, fmt.Errorf("missing connection_url")
|
||||
}
|
||||
|
||||
haEnabled := conf["ha_enabled"] == "true"
|
||||
|
||||
dbTable := conf["table"]
|
||||
if dbTable == "" {
|
||||
dbTable = defaultTableName
|
||||
}
|
||||
|
||||
err := validateDBTable(dbTable)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid table: %w", err)
|
||||
}
|
||||
|
||||
dbHATable, ok := conf["ha_table"]
|
||||
if !ok {
|
||||
dbHATable = defaultHATableName
|
||||
}
|
||||
|
||||
err = validateDBTable(dbHATable)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid HA table: %w", err)
|
||||
}
|
||||
|
||||
maxParStr, ok := conf["max_parallel"]
|
||||
var maxParInt int
|
||||
if ok {
|
||||
maxParInt, err = strconv.Atoi(maxParStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
|
||||
}
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
||||
}
|
||||
}
|
||||
|
||||
// Create CockroachDB handle for the database.
|
||||
db, err := sql.Open("pgx", connURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to cockroachdb: %w", err)
|
||||
}
|
||||
|
||||
// Create the required tables if they don't exist.
|
||||
createQuery := "CREATE TABLE IF NOT EXISTS " + dbTable +
|
||||
" (path STRING, value BYTES, PRIMARY KEY (path))"
|
||||
if _, err := db.Exec(createQuery); err != nil {
|
||||
return nil, fmt.Errorf("failed to create CockroachDB table: %w", err)
|
||||
}
|
||||
if haEnabled {
|
||||
createHATableQuery := "CREATE TABLE IF NOT EXISTS " + dbHATable +
|
||||
"(ha_key TEXT NOT NULL, " +
|
||||
" ha_identity TEXT NOT NULL, " +
|
||||
" ha_value TEXT, " +
|
||||
" valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " +
|
||||
" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
|
||||
");"
|
||||
if _, err := db.Exec(createHATableQuery); err != nil {
|
||||
return nil, fmt.Errorf("failed to create CockroachDB HA table: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Setup the backend
|
||||
c := &CockroachDBBackend{
|
||||
table: dbTable,
|
||||
haTable: dbHATable,
|
||||
client: db,
|
||||
rawStatements: map[string]string{
|
||||
"put": "INSERT INTO " + dbTable + " VALUES($1, $2)" +
|
||||
" ON CONFLICT (path) DO " +
|
||||
" UPDATE SET (path, value) = ($1, $2)",
|
||||
"get": "SELECT value FROM " + dbTable + " WHERE path = $1",
|
||||
"delete": "DELETE FROM " + dbTable + " WHERE path = $1",
|
||||
"list": "SELECT path FROM " + dbTable + " WHERE path LIKE $1",
|
||||
},
|
||||
statements: make(map[string]*sql.Stmt),
|
||||
rawHAStatements: map[string]string{
|
||||
"get": "SELECT ha_value FROM " + dbHATable + " WHERE NOW() <= valid_until AND ha_key = $1",
|
||||
"upsert": "INSERT INTO " + dbHATable + " as t (ha_identity, ha_key, ha_value, valid_until)" +
|
||||
" VALUES ($1, $2, $3, NOW() + $4) " +
|
||||
" ON CONFLICT (ha_key) DO " +
|
||||
" UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4) " +
|
||||
" WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " +
|
||||
" (t.ha_identity = $1 AND t.ha_key = $2) ",
|
||||
"delete": "DELETE FROM " + dbHATable + " WHERE ha_key = $1",
|
||||
},
|
||||
haStatements: make(map[string]*sql.Stmt),
|
||||
logger: logger,
|
||||
permitPool: physical.NewPermitPool(maxParInt),
|
||||
haEnabled: haEnabled,
|
||||
}
|
||||
|
||||
// Prepare all the statements required
|
||||
for name, query := range c.rawStatements {
|
||||
if err := c.prepare(c.statements, name, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if haEnabled {
|
||||
for name, query := range c.rawHAStatements {
|
||||
if err := c.prepare(c.haStatements, name, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// prepare is a helper to prepare a query for future execution.
|
||||
func (c *CockroachDBBackend) prepare(statementMap map[string]*sql.Stmt, name, query string) error {
|
||||
stmt, err := c.client.Prepare(query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare %q: %w", name, err)
|
||||
}
|
||||
statementMap[name] = stmt
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put is used to insert or update an entry.
|
||||
func (c *CockroachDBBackend) Put(ctx context.Context, entry *physical.Entry) error {
|
||||
defer metrics.MeasureSince([]string{"cockroachdb", "put"}, time.Now())
|
||||
|
||||
c.permitPool.Acquire()
|
||||
defer c.permitPool.Release()
|
||||
|
||||
_, err := c.statements["put"].Exec(entry.Key, entry.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get is used to fetch and entry.
|
||||
func (c *CockroachDBBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
|
||||
defer metrics.MeasureSince([]string{"cockroachdb", "get"}, time.Now())
|
||||
|
||||
c.permitPool.Acquire()
|
||||
defer c.permitPool.Release()
|
||||
|
||||
var result []byte
|
||||
err := c.statements["get"].QueryRow(key).Scan(&result)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ent := &physical.Entry{
|
||||
Key: key,
|
||||
Value: result,
|
||||
}
|
||||
return ent, nil
|
||||
}
|
||||
|
||||
// Delete is used to permanently delete an entry
|
||||
func (c *CockroachDBBackend) Delete(ctx context.Context, key string) error {
|
||||
defer metrics.MeasureSince([]string{"cockroachdb", "delete"}, time.Now())
|
||||
|
||||
c.permitPool.Acquire()
|
||||
defer c.permitPool.Release()
|
||||
|
||||
_, err := c.statements["delete"].Exec(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List is used to list all the keys under a given
|
||||
// prefix, up to the next prefix.
|
||||
func (c *CockroachDBBackend) List(ctx context.Context, prefix string) ([]string, error) {
|
||||
defer metrics.MeasureSince([]string{"cockroachdb", "list"}, time.Now())
|
||||
|
||||
c.permitPool.Acquire()
|
||||
defer c.permitPool.Release()
|
||||
|
||||
likePrefix := prefix + "%"
|
||||
rows, err := c.statements["list"].Query(likePrefix)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var keys []string
|
||||
for rows.Next() {
|
||||
var key string
|
||||
err = rows.Scan(&key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan rows: %w", err)
|
||||
}
|
||||
|
||||
key = strings.TrimPrefix(key, prefix)
|
||||
if i := strings.Index(key, "/"); i == -1 {
|
||||
// Add objects only from the current 'folder'
|
||||
keys = append(keys, key)
|
||||
} else if i != -1 {
|
||||
// Add truncated 'folder' paths
|
||||
keys = strutil.AppendIfMissing(keys, string(key[:i+1]))
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(keys)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// Transaction is used to run multiple entries via a transaction
|
||||
func (c *CockroachDBBackend) Transaction(ctx context.Context, txns []*physical.TxnEntry) error {
|
||||
defer metrics.MeasureSince([]string{"cockroachdb", "transaction"}, time.Now())
|
||||
if len(txns) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.permitPool.Acquire()
|
||||
defer c.permitPool.Release()
|
||||
|
||||
return crdb.ExecuteTx(context.Background(), c.client, nil, func(tx *sql.Tx) error {
|
||||
return c.transaction(tx, txns)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *CockroachDBBackend) transaction(tx *sql.Tx, txns []*physical.TxnEntry) error {
|
||||
deleteStmt, err := tx.Prepare(c.rawStatements["delete"])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
putStmt, err := tx.Prepare(c.rawStatements["put"])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, op := range txns {
|
||||
switch op.Operation {
|
||||
case physical.DeleteOperation:
|
||||
_, err = deleteStmt.Exec(op.Entry.Key)
|
||||
case physical.PutOperation:
|
||||
_, err = putStmt.Exec(op.Entry.Key, op.Entry.Value)
|
||||
default:
|
||||
return fmt.Errorf("%q is not a supported transaction operation", op.Operation)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateDBTable against the CockroachDB rules for table names:
|
||||
// https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#identifiers
|
||||
//
|
||||
// - All values that accept an identifier must:
|
||||
// - Begin with a Unicode letter or an underscore (_). Subsequent characters can be letters,
|
||||
// - underscores, digits (0-9), or dollar signs ($).
|
||||
// - Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example,
|
||||
// name accepts Unreserved or Column Name keywords.
|
||||
//
|
||||
// The docs do state that we can bypass these rules with double quotes, however I think it
|
||||
// is safer to just require these rules across the board.
|
||||
func validateDBTable(dbTable string) (err error) {
|
||||
// Check if this is 'database.table' formatted. If so, split them apart and check the two
|
||||
// parts from each other
|
||||
split := strings.SplitN(dbTable, ".", 2)
|
||||
if len(split) == 2 {
|
||||
merr := &multierror.Error{}
|
||||
merr = multierror.Append(merr, wrapErr("invalid database: %w", validateDBTable(split[0])))
|
||||
merr = multierror.Append(merr, wrapErr("invalid table name: %w", validateDBTable(split[1])))
|
||||
return merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// Disallow SQL keywords as the table name
|
||||
if sqlKeywords[strings.ToUpper(dbTable)] {
|
||||
return fmt.Errorf("name must not be a SQL keyword")
|
||||
}
|
||||
|
||||
runes := []rune(dbTable)
|
||||
for i, r := range runes {
|
||||
if i == 0 && !unicode.IsLetter(r) && r != '_' {
|
||||
return fmt.Errorf("must use a letter or an underscore as the first character")
|
||||
}
|
||||
|
||||
if !unicode.IsLetter(r) && r != '_' && !unicode.IsDigit(r) && r != '$' {
|
||||
return fmt.Errorf("must only contain letters, underscores, digits, and dollar signs")
|
||||
}
|
||||
|
||||
if r == '`' || r == '\'' || r == '"' {
|
||||
return fmt.Errorf("cannot contain backticks, single quotes, or double quotes")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func wrapErr(message string, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf(message, err)
|
||||
}
|
|
@ -1,204 +0,0 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package cockroachdb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
)
|
||||
|
||||
const (
|
||||
// The lock TTL matches the default that Consul API uses, 15 seconds.
|
||||
// Used as part of SQL commands to set/extend lock expiry time relative to
|
||||
// database clock.
|
||||
CockroachDBLockTTLSeconds = 15
|
||||
|
||||
// The amount of time to wait between the lock renewals
|
||||
CockroachDBLockRenewInterval = 5 * time.Second
|
||||
|
||||
// CockroachDBLockRetryInterval is the amount of time to wait
|
||||
// if a lock fails before trying again.
|
||||
CockroachDBLockRetryInterval = time.Second
|
||||
)
|
||||
|
||||
// Verify backend satisfies the correct interfaces.
|
||||
var (
|
||||
_ physical.HABackend = (*CockroachDBBackend)(nil)
|
||||
_ physical.Lock = (*CockroachDBLock)(nil)
|
||||
)
|
||||
|
||||
type CockroachDBLock struct {
|
||||
backend *CockroachDBBackend
|
||||
key string
|
||||
value string
|
||||
identity string
|
||||
lock sync.Mutex
|
||||
|
||||
renewTicker *time.Ticker
|
||||
|
||||
// ttlSeconds is how long a lock is valid for.
|
||||
ttlSeconds int
|
||||
|
||||
// renewInterval is how much time to wait between lock renewals. must be << ttl.
|
||||
renewInterval time.Duration
|
||||
|
||||
// retryInterval is how much time to wait between attempts to grab the lock.
|
||||
retryInterval time.Duration
|
||||
}
|
||||
|
||||
func (c *CockroachDBBackend) HAEnabled() bool {
|
||||
return c.haEnabled
|
||||
}
|
||||
|
||||
func (c *CockroachDBBackend) LockWith(key, value string) (physical.Lock, error) {
|
||||
identity, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CockroachDBLock{
|
||||
backend: c,
|
||||
key: key,
|
||||
value: value,
|
||||
identity: identity,
|
||||
ttlSeconds: CockroachDBLockTTLSeconds,
|
||||
renewInterval: CockroachDBLockRenewInterval,
|
||||
retryInterval: CockroachDBLockRetryInterval,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lock tries to acquire the lock by repeatedly trying to create a record in the
|
||||
// CockroachDB table. It will block until either the stop channel is closed or
|
||||
// the lock could be acquired successfully. The returned channel will be closed
|
||||
// once the lock in the CockroachDB table cannot be renewed, either due to an
|
||||
// error speaking to CockroachDB or because someone else has taken it.
|
||||
func (l *CockroachDBLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
var (
|
||||
success = make(chan struct{})
|
||||
errors = make(chan error, 1)
|
||||
leader = make(chan struct{})
|
||||
)
|
||||
go l.tryToLock(stopCh, success, errors)
|
||||
|
||||
select {
|
||||
case <-success:
|
||||
// After acquiring it successfully, we must renew the lock periodically.
|
||||
l.renewTicker = time.NewTicker(l.renewInterval)
|
||||
go l.periodicallyRenewLock(leader)
|
||||
case err := <-errors:
|
||||
return nil, err
|
||||
case <-stopCh:
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return leader, nil
|
||||
}
|
||||
|
||||
// Unlock releases the lock by deleting the lock record from the
|
||||
// CockroachDB table.
|
||||
func (l *CockroachDBLock) Unlock() error {
|
||||
c := l.backend
|
||||
c.permitPool.Acquire()
|
||||
defer c.permitPool.Release()
|
||||
|
||||
if l.renewTicker != nil {
|
||||
l.renewTicker.Stop()
|
||||
}
|
||||
|
||||
_, err := c.haStatements["delete"].Exec(l.key)
|
||||
return err
|
||||
}
|
||||
|
||||
// Value checks whether or not the lock is held by any instance of CockroachDBLock,
|
||||
// including this one, and returns the current value.
|
||||
func (l *CockroachDBLock) Value() (bool, string, error) {
|
||||
c := l.backend
|
||||
c.permitPool.Acquire()
|
||||
defer c.permitPool.Release()
|
||||
var result string
|
||||
err := c.haStatements["get"].QueryRow(l.key).Scan(&result)
|
||||
|
||||
switch err {
|
||||
case nil:
|
||||
return true, result, nil
|
||||
case sql.ErrNoRows:
|
||||
return false, "", nil
|
||||
default:
|
||||
return false, "", err
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// tryToLock tries to create a new item in CockroachDB every `retryInterval`.
|
||||
// As long as the item cannot be created (because it already exists), it will
|
||||
// be retried. If the operation fails due to an error, it is sent to the errors
|
||||
// channel. When the lock could be acquired successfully, the success channel
|
||||
// is closed.
|
||||
func (l *CockroachDBLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) {
|
||||
ticker := time.NewTicker(l.retryInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
gotlock, err := l.writeItem()
|
||||
switch {
|
||||
case err != nil:
|
||||
// Send to the error channel and don't block if full.
|
||||
select {
|
||||
case errors <- err:
|
||||
default:
|
||||
}
|
||||
return
|
||||
case gotlock:
|
||||
close(success)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *CockroachDBLock) periodicallyRenewLock(done chan struct{}) {
|
||||
for range l.renewTicker.C {
|
||||
gotlock, err := l.writeItem()
|
||||
if err != nil || !gotlock {
|
||||
close(done)
|
||||
l.renewTicker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Attempts to put/update the CockroachDB item using condition expressions to
|
||||
// evaluate the TTL. Returns true if the lock was obtained, false if not.
|
||||
// If false error may be nil or non-nil: nil indicates simply that someone
|
||||
// else has the lock, whereas non-nil means that something unexpected happened.
|
||||
func (l *CockroachDBLock) writeItem() (bool, error) {
|
||||
c := l.backend
|
||||
c.permitPool.Acquire()
|
||||
defer c.permitPool.Release()
|
||||
|
||||
sqlResult, err := c.haStatements["upsert"].Exec(l.identity, l.key, l.value, fmt.Sprintf("%d seconds", l.ttlSeconds))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if sqlResult == nil {
|
||||
return false, fmt.Errorf("empty SQL response received")
|
||||
}
|
||||
|
||||
ar, err := sqlResult.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return ar == 1, nil
|
||||
}
|
|
@ -1,214 +0,0 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package cockroachdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/sdk/helper/docker"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
docker.ServiceURL
|
||||
TableName string
|
||||
HATableName string
|
||||
}
|
||||
|
||||
var _ docker.ServiceConfig = &Config{}
|
||||
|
||||
func prepareCockroachDBTestContainer(t *testing.T) (func(), *Config) {
|
||||
if retURL := os.Getenv("CR_URL"); retURL != "" {
|
||||
s, err := docker.NewServiceURLParse(retURL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return func() {}, &Config{
|
||||
ServiceURL: *s,
|
||||
TableName: "vault." + defaultTableName,
|
||||
HATableName: "vault." + defaultHATableName,
|
||||
}
|
||||
}
|
||||
|
||||
runner, err := docker.NewServiceRunner(docker.RunOptions{
|
||||
ImageRepo: "docker.mirror.hashicorp.services/cockroachdb/cockroach",
|
||||
ImageTag: "release-1.0",
|
||||
ContainerName: "cockroachdb",
|
||||
Cmd: []string{"start", "--insecure"},
|
||||
Ports: []string{"26257/tcp"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start docker CockroachDB: %s", err)
|
||||
}
|
||||
svc, err := runner.StartService(context.Background(), connectCockroachDB)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start docker CockroachDB: %s", err)
|
||||
}
|
||||
|
||||
return svc.Cleanup, svc.Config.(*Config)
|
||||
}
|
||||
|
||||
func connectCockroachDB(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
|
||||
u := url.URL{
|
||||
Scheme: "postgresql",
|
||||
User: url.UserPassword("root", ""),
|
||||
Host: fmt.Sprintf("%s:%d", host, port),
|
||||
RawQuery: "sslmode=disable",
|
||||
}
|
||||
|
||||
db, err := sql.Open("pgx", u.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
database := "vault"
|
||||
_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", database))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Config{
|
||||
ServiceURL: *docker.NewServiceURL(u),
|
||||
TableName: database + "." + defaultTableName,
|
||||
HATableName: database + "." + defaultHATableName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestCockroachDBBackend(t *testing.T) {
|
||||
cleanup, config := prepareCockroachDBTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
hae := os.Getenv("CR_HA_ENABLED")
|
||||
if hae == "" {
|
||||
hae = "true"
|
||||
}
|
||||
|
||||
// Run vault tests
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
b1, err := NewCockroachDBBackend(map[string]string{
|
||||
"connection_url": config.URL().String(),
|
||||
"table": config.TableName,
|
||||
"ha_table": config.HATableName,
|
||||
"ha_enabled": hae,
|
||||
}, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
b2, err := NewCockroachDBBackend(map[string]string{
|
||||
"connection_url": config.URL().String(),
|
||||
"table": config.TableName,
|
||||
"ha_table": config.HATableName,
|
||||
"ha_enabled": hae,
|
||||
}, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
truncate(t, b1)
|
||||
truncate(t, b2)
|
||||
}()
|
||||
|
||||
physical.ExerciseBackend(t, b1)
|
||||
truncate(t, b1)
|
||||
physical.ExerciseBackend_ListPrefix(t, b1)
|
||||
truncate(t, b1)
|
||||
physical.ExerciseTransactionalBackend(t, b1)
|
||||
truncate(t, b1)
|
||||
|
||||
ha1, ok1 := b1.(physical.HABackend)
|
||||
ha2, ok2 := b2.(physical.HABackend)
|
||||
if !ok1 || !ok2 {
|
||||
t.Fatalf("CockroachDB does not implement HABackend")
|
||||
}
|
||||
|
||||
if ha1.HAEnabled() && ha2.HAEnabled() {
|
||||
logger.Info("Running ha backend tests")
|
||||
physical.ExerciseHABackend(t, ha1, ha2)
|
||||
}
|
||||
}
|
||||
|
||||
func truncate(t *testing.T, b physical.Backend) {
|
||||
crdb := b.(*CockroachDBBackend)
|
||||
_, err := crdb.client.Exec("TRUNCATE TABLE " + crdb.table)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to drop table: %v", err)
|
||||
}
|
||||
if crdb.haEnabled {
|
||||
_, err = crdb.client.Exec("TRUNCATE TABLE " + crdb.haTable)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to drop table: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateDBTable(t *testing.T) {
|
||||
type testCase struct {
|
||||
table string
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"first character is letter": {"abcdef", false},
|
||||
"first character is underscore": {"_bcdef", false},
|
||||
"exclamation point": {"ab!def", true},
|
||||
"at symbol": {"ab@def", true},
|
||||
"hash": {"ab#def", true},
|
||||
"percent": {"ab%def", true},
|
||||
"carrot": {"ab^def", true},
|
||||
"ampersand": {"ab&def", true},
|
||||
"star": {"ab*def", true},
|
||||
"left paren": {"ab(def", true},
|
||||
"right paren": {"ab)def", true},
|
||||
"dash": {"ab-def", true},
|
||||
"digit": {"a123ef", false},
|
||||
"dollar end": {"abcde$", false},
|
||||
"dollar middle": {"ab$def", false},
|
||||
"dollar start": {"$bcdef", true},
|
||||
"backtick prefix": {"`bcdef", true},
|
||||
"backtick middle": {"ab`def", true},
|
||||
"backtick suffix": {"abcde`", true},
|
||||
"single quote prefix": {"'bcdef", true},
|
||||
"single quote middle": {"ab'def", true},
|
||||
"single quote suffix": {"abcde'", true},
|
||||
"double quote prefix": {`"bcdef`, true},
|
||||
"double quote middle": {`ab"def`, true},
|
||||
"double quote suffix": {`abcde"`, true},
|
||||
"underscore with all runes": {"_bcd123__a__$", false},
|
||||
"all runes": {"abcd123__a__$", false},
|
||||
"default table name": {defaultTableName, false},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := validateDBTable(test.table)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
})
|
||||
t.Run(fmt.Sprintf("database: %s", name), func(t *testing.T) {
|
||||
dbTable := fmt.Sprintf("%s.%s", test.table, test.table)
|
||||
err := validateDBTable(dbTable)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,441 +0,0 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package cockroachdb
|
||||
|
||||
// sqlKeywords is a reference of all of the keywords that we do not allow for use as the table name
|
||||
// Referenced from:
|
||||
// https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#identifiers
|
||||
// -> https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords
|
||||
// -> https://www.cockroachlabs.com/docs/stable/sql-grammar.html
|
||||
var sqlKeywords = map[string]bool{
|
||||
// reserved_keyword
|
||||
// https://www.cockroachlabs.com/docs/stable/sql-grammar.html#reserved_keyword
|
||||
"ALL": true,
|
||||
"ANALYSE": true,
|
||||
"ANALYZE": true,
|
||||
"AND": true,
|
||||
"ANY": true,
|
||||
"ARRAY": true,
|
||||
"AS": true,
|
||||
"ASC": true,
|
||||
"ASYMMETRIC": true,
|
||||
"BOTH": true,
|
||||
"CASE": true,
|
||||
"CAST": true,
|
||||
"CHECK": true,
|
||||
"COLLATE": true,
|
||||
"COLUMN": true,
|
||||
"CONCURRENTLY": true,
|
||||
"CONSTRAINT": true,
|
||||
"CREATE": true,
|
||||
"CURRENT_CATALOG": true,
|
||||
"CURRENT_DATE": true,
|
||||
"CURRENT_ROLE": true,
|
||||
"CURRENT_SCHEMA": true,
|
||||
"CURRENT_TIME": true,
|
||||
"CURRENT_TIMESTAMP": true,
|
||||
"CURRENT_USER": true,
|
||||
"DEFAULT": true,
|
||||
"DEFERRABLE": true,
|
||||
"DESC": true,
|
||||
"DISTINCT": true,
|
||||
"DO": true,
|
||||
"ELSE": true,
|
||||
"END": true,
|
||||
"EXCEPT": true,
|
||||
"FALSE": true,
|
||||
"FETCH": true,
|
||||
"FOR": true,
|
||||
"FOREIGN": true,
|
||||
"FROM": true,
|
||||
"GRANT": true,
|
||||
"GROUP": true,
|
||||
"HAVING": true,
|
||||
"IN": true,
|
||||
"INITIALLY": true,
|
||||
"INTERSECT": true,
|
||||
"INTO": true,
|
||||
"LATERAL": true,
|
||||
"LEADING": true,
|
||||
"LIMIT": true,
|
||||
"LOCALTIME": true,
|
||||
"LOCALTIMESTAMP": true,
|
||||
"NOT": true,
|
||||
"NULL": true,
|
||||
"OFFSET": true,
|
||||
"ON": true,
|
||||
"ONLY": true,
|
||||
"OR": true,
|
||||
"ORDER": true,
|
||||
"PLACING": true,
|
||||
"PRIMARY": true,
|
||||
"REFERENCES": true,
|
||||
"RETURNING": true,
|
||||
"SELECT": true,
|
||||
"SESSION_USER": true,
|
||||
"SOME": true,
|
||||
"SYMMETRIC": true,
|
||||
"TABLE": true,
|
||||
"THEN": true,
|
||||
"TO": true,
|
||||
"TRAILING": true,
|
||||
"TRUE": true,
|
||||
"UNION": true,
|
||||
"UNIQUE": true,
|
||||
"USER": true,
|
||||
"USING": true,
|
||||
"VARIADIC": true,
|
||||
"WHEN": true,
|
||||
"WHERE": true,
|
||||
"WINDOW": true,
|
||||
"WITH": true,
|
||||
|
||||
// cockroachdb_extra_reserved_keyword
|
||||
// https://www.cockroachlabs.com/docs/stable/sql-grammar.html#cockroachdb_extra_reserved_keyword
|
||||
"INDEX": true,
|
||||
"NOTHING": true,
|
||||
|
||||
// type_func_name_keyword
|
||||
// https://www.cockroachlabs.com/docs/stable/sql-grammar.html#type_func_name_keyword
|
||||
"COLLATION": true,
|
||||
"CROSS": true,
|
||||
"FULL": true,
|
||||
"INNER": true,
|
||||
"ILIKE": true,
|
||||
"IS": true,
|
||||
"ISNULL": true,
|
||||
"JOIN": true,
|
||||
"LEFT": true,
|
||||
"LIKE": true,
|
||||
"NATURAL": true,
|
||||
"NONE": true,
|
||||
"NOTNULL": true,
|
||||
"OUTER": true,
|
||||
"OVERLAPS": true,
|
||||
"RIGHT": true,
|
||||
"SIMILAR": true,
|
||||
"FAMILY": true,
|
||||
|
||||
// col_name_keyword
|
||||
// https://www.cockroachlabs.com/docs/stable/sql-grammar.html#col_name_keyword
|
||||
"ANNOTATE_TYPE": true,
|
||||
"BETWEEN": true,
|
||||
"BIGINT": true,
|
||||
"BIT": true,
|
||||
"BOOLEAN": true,
|
||||
"CHAR": true,
|
||||
"CHARACTER": true,
|
||||
"CHARACTERISTICS": true,
|
||||
"COALESCE": true,
|
||||
"DEC": true,
|
||||
"DECIMAL": true,
|
||||
"EXISTS": true,
|
||||
"EXTRACT": true,
|
||||
"EXTRACT_DURATION": true,
|
||||
"FLOAT": true,
|
||||
"GREATEST": true,
|
||||
"GROUPING": true,
|
||||
"IF": true,
|
||||
"IFERROR": true,
|
||||
"IFNULL": true,
|
||||
"INT": true,
|
||||
"INTEGER": true,
|
||||
"INTERVAL": true,
|
||||
"ISERROR": true,
|
||||
"LEAST": true,
|
||||
"NULLIF": true,
|
||||
"NUMERIC": true,
|
||||
"OUT": true,
|
||||
"OVERLAY": true,
|
||||
"POSITION": true,
|
||||
"PRECISION": true,
|
||||
"REAL": true,
|
||||
"ROW": true,
|
||||
"SMALLINT": true,
|
||||
"SUBSTRING": true,
|
||||
"TIME": true,
|
||||
"TIMETZ": true,
|
||||
"TIMESTAMP": true,
|
||||
"TIMESTAMPTZ": true,
|
||||
"TREAT": true,
|
||||
"TRIM": true,
|
||||
"VALUES": true,
|
||||
"VARBIT": true,
|
||||
"VARCHAR": true,
|
||||
"VIRTUAL": true,
|
||||
"WORK": true,
|
||||
|
||||
// unreserved_keyword
|
||||
// https://www.cockroachlabs.com/docs/stable/sql-grammar.html#unreserved_keyword
|
||||
"ABORT": true,
|
||||
"ACTION": true,
|
||||
"ADD": true,
|
||||
"ADMIN": true,
|
||||
"AGGREGATE": true,
|
||||
"ALTER": true,
|
||||
"AT": true,
|
||||
"AUTOMATIC": true,
|
||||
"AUTHORIZATION": true,
|
||||
"BACKUP": true,
|
||||
"BEGIN": true,
|
||||
"BIGSERIAL": true,
|
||||
"BLOB": true,
|
||||
"BOOL": true,
|
||||
"BUCKET_COUNT": true,
|
||||
"BUNDLE": true,
|
||||
"BY": true,
|
||||
"BYTEA": true,
|
||||
"BYTES": true,
|
||||
"CACHE": true,
|
||||
"CANCEL": true,
|
||||
"CASCADE": true,
|
||||
"CHANGEFEED": true,
|
||||
"CLUSTER": true,
|
||||
"COLUMNS": true,
|
||||
"COMMENT": true,
|
||||
"COMMIT": true,
|
||||
"COMMITTED": true,
|
||||
"COMPACT": true,
|
||||
"COMPLETE": true,
|
||||
"CONFLICT": true,
|
||||
"CONFIGURATION": true,
|
||||
"CONFIGURATIONS": true,
|
||||
"CONFIGURE": true,
|
||||
"CONSTRAINTS": true,
|
||||
"CONVERSION": true,
|
||||
"COPY": true,
|
||||
"COVERING": true,
|
||||
"CREATEROLE": true,
|
||||
"CUBE": true,
|
||||
"CURRENT": true,
|
||||
"CYCLE": true,
|
||||
"DATA": true,
|
||||
"DATABASE": true,
|
||||
"DATABASES": true,
|
||||
"DATE": true,
|
||||
"DAY": true,
|
||||
"DEALLOCATE": true,
|
||||
"DELETE": true,
|
||||
"DEFERRED": true,
|
||||
"DISCARD": true,
|
||||
"DOMAIN": true,
|
||||
"DOUBLE": true,
|
||||
"DROP": true,
|
||||
"ENCODING": true,
|
||||
"ENUM": true,
|
||||
"ESCAPE": true,
|
||||
"EXCLUDE": true,
|
||||
"EXECUTE": true,
|
||||
"EXPERIMENTAL": true,
|
||||
"EXPERIMENTAL_AUDIT": true,
|
||||
"EXPERIMENTAL_FINGERPRINTS": true,
|
||||
"EXPERIMENTAL_RELOCATE": true,
|
||||
"EXPERIMENTAL_REPLICA": true,
|
||||
"EXPIRATION": true,
|
||||
"EXPLAIN": true,
|
||||
"EXPORT": true,
|
||||
"EXTENSION": true,
|
||||
"FILES": true,
|
||||
"FILTER": true,
|
||||
"FIRST": true,
|
||||
"FLOAT4": true,
|
||||
"FLOAT8": true,
|
||||
"FOLLOWING": true,
|
||||
"FORCE_INDEX": true,
|
||||
"FUNCTION": true,
|
||||
"GLOBAL": true,
|
||||
"GRANTS": true,
|
||||
"GROUPS": true,
|
||||
"HASH": true,
|
||||
"HIGH": true,
|
||||
"HISTOGRAM": true,
|
||||
"HOUR": true,
|
||||
"IMMEDIATE": true,
|
||||
"IMPORT": true,
|
||||
"INCLUDE": true,
|
||||
"INCREMENT": true,
|
||||
"INCREMENTAL": true,
|
||||
"INDEXES": true,
|
||||
"INET": true,
|
||||
"INJECT": true,
|
||||
"INSERT": true,
|
||||
"INT2": true,
|
||||
"INT2VECTOR": true,
|
||||
"INT4": true,
|
||||
"INT8": true,
|
||||
"INT64": true,
|
||||
"INTERLEAVE": true,
|
||||
"INVERTED": true,
|
||||
"ISOLATION": true,
|
||||
"JOB": true,
|
||||
"JOBS": true,
|
||||
"JSON": true,
|
||||
"JSONB": true,
|
||||
"KEY": true,
|
||||
"KEYS": true,
|
||||
"KV": true,
|
||||
"LANGUAGE": true,
|
||||
"LAST": true,
|
||||
"LC_COLLATE": true,
|
||||
"LC_CTYPE": true,
|
||||
"LEASE": true,
|
||||
"LESS": true,
|
||||
"LEVEL": true,
|
||||
"LIST": true,
|
||||
"LOCAL": true,
|
||||
"LOCKED": true,
|
||||
"LOGIN": true,
|
||||
"LOOKUP": true,
|
||||
"LOW": true,
|
||||
"MATCH": true,
|
||||
"MATERIALIZED": true,
|
||||
"MAXVALUE": true,
|
||||
"MERGE": true,
|
||||
"MINUTE": true,
|
||||
"MINVALUE": true,
|
||||
"MONTH": true,
|
||||
"NAMES": true,
|
||||
"NAN": true,
|
||||
"NAME": true,
|
||||
"NEXT": true,
|
||||
"NO": true,
|
||||
"NORMAL": true,
|
||||
"NO_INDEX_JOIN": true,
|
||||
"NOCREATEROLE": true,
|
||||
"NOLOGIN": true,
|
||||
"NOWAIT": true,
|
||||
"NULLS": true,
|
||||
"IGNORE_FOREIGN_KEYS": true,
|
||||
"OF": true,
|
||||
"OFF": true,
|
||||
"OID": true,
|
||||
"OIDS": true,
|
||||
"OIDVECTOR": true,
|
||||
"OPERATOR": true,
|
||||
"OPT": true,
|
||||
"OPTION": true,
|
||||
"OPTIONS": true,
|
||||
"ORDINALITY": true,
|
||||
"OTHERS": true,
|
||||
"OVER": true,
|
||||
"OWNED": true,
|
||||
"PARENT": true,
|
||||
"PARTIAL": true,
|
||||
"PARTITION": true,
|
||||
"PARTITIONS": true,
|
||||
"PASSWORD": true,
|
||||
"PAUSE": true,
|
||||
"PHYSICAL": true,
|
||||
"PLAN": true,
|
||||
"PLANS": true,
|
||||
"PRECEDING": true,
|
||||
"PREPARE": true,
|
||||
"PRESERVE": true,
|
||||
"PRIORITY": true,
|
||||
"PUBLIC": true,
|
||||
"PUBLICATION": true,
|
||||
"QUERIES": true,
|
||||
"QUERY": true,
|
||||
"RANGE": true,
|
||||
"RANGES": true,
|
||||
"READ": true,
|
||||
"RECURSIVE": true,
|
||||
"REF": true,
|
||||
"REGCLASS": true,
|
||||
"REGPROC": true,
|
||||
"REGPROCEDURE": true,
|
||||
"REGNAMESPACE": true,
|
||||
"REGTYPE": true,
|
||||
"REINDEX": true,
|
||||
"RELEASE": true,
|
||||
"RENAME": true,
|
||||
"REPEATABLE": true,
|
||||
"REPLACE": true,
|
||||
"RESET": true,
|
||||
"RESTORE": true,
|
||||
"RESTRICT": true,
|
||||
"RESUME": true,
|
||||
"REVOKE": true,
|
||||
"ROLE": true,
|
||||
"ROLES": true,
|
||||
"ROLLBACK": true,
|
||||
"ROLLUP": true,
|
||||
"ROWS": true,
|
||||
"RULE": true,
|
||||
"SETTING": true,
|
||||
"SETTINGS": true,
|
||||
"STATUS": true,
|
||||
"SAVEPOINT": true,
|
||||
"SCATTER": true,
|
||||
"SCHEMA": true,
|
||||
"SCHEMAS": true,
|
||||
"SCRUB": true,
|
||||
"SEARCH": true,
|
||||
"SECOND": true,
|
||||
"SERIAL": true,
|
||||
"SERIALIZABLE": true,
|
||||
"SERIAL2": true,
|
||||
"SERIAL4": true,
|
||||
"SERIAL8": true,
|
||||
"SEQUENCE": true,
|
||||
"SEQUENCES": true,
|
||||
"SERVER": true,
|
||||
"SESSION": true,
|
||||
"SESSIONS": true,
|
||||
"SET": true,
|
||||
"SHARE": true,
|
||||
"SHOW": true,
|
||||
"SIMPLE": true,
|
||||
"SKIP": true,
|
||||
"SMALLSERIAL": true,
|
||||
"SNAPSHOT": true,
|
||||
"SPLIT": true,
|
||||
"SQL": true,
|
||||
"START": true,
|
||||
"STATISTICS": true,
|
||||
"STDIN": true,
|
||||
"STORE": true,
|
||||
"STORED": true,
|
||||
"STORING": true,
|
||||
"STRICT": true,
|
||||
"STRING": true,
|
||||
"SUBSCRIPTION": true,
|
||||
"SYNTAX": true,
|
||||
"SYSTEM": true,
|
||||
"TABLES": true,
|
||||
"TEMP": true,
|
||||
"TEMPLATE": true,
|
||||
"TEMPORARY": true,
|
||||
"TESTING_RELOCATE": true,
|
||||
"TEXT": true,
|
||||
"TIES": true,
|
||||
"TRACE": true,
|
||||
"TRANSACTION": true,
|
||||
"TRIGGER": true,
|
||||
"TRUNCATE": true,
|
||||
"TRUSTED": true,
|
||||
"TYPE": true,
|
||||
"THROTTLING": true,
|
||||
"UNBOUNDED": true,
|
||||
"UNCOMMITTED": true,
|
||||
"UNKNOWN": true,
|
||||
"UNLOGGED": true,
|
||||
"UNSPLIT": true,
|
||||
"UNTIL": true,
|
||||
"UPDATE": true,
|
||||
"UPSERT": true,
|
||||
"UUID": true,
|
||||
"USE": true,
|
||||
"USERS": true,
|
||||
"VALID": true,
|
||||
"VALIDATE": true,
|
||||
"VALUE": true,
|
||||
"VARYING": true,
|
||||
"VIEW": true,
|
||||
"WITHIN": true,
|
||||
"WITHOUT": true,
|
||||
"WRITE": true,
|
||||
"YEAR": true,
|
||||
"ZONE": true,
|
||||
}
|
|
@ -1,779 +0,0 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
mysql "github.com/go-sql-driver/mysql"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
)
|
||||
|
||||
// Verify MySQLBackend satisfies the correct interfaces
|
||||
var (
|
||||
_ physical.Backend = (*MySQLBackend)(nil)
|
||||
_ physical.HABackend = (*MySQLBackend)(nil)
|
||||
_ physical.Lock = (*MySQLHALock)(nil)
|
||||
)
|
||||
|
||||
// Unreserved tls key
|
||||
// Reserved values are "true", "false", "skip-verify"
|
||||
const mysqlTLSKey = "default"
|
||||
|
||||
// MySQLBackend is a physical backend that stores data
|
||||
// within MySQL database.
|
||||
type MySQLBackend struct {
|
||||
dbTable string
|
||||
dbLockTable string
|
||||
client *sql.DB
|
||||
statements map[string]*sql.Stmt
|
||||
logger log.Logger
|
||||
permitPool *physical.PermitPool
|
||||
conf map[string]string
|
||||
redirectHost string
|
||||
redirectPort int64
|
||||
haEnabled bool
|
||||
}
|
||||
|
||||
// NewMySQLBackend constructs a MySQL backend using the given API client and
|
||||
// server address and credential for accessing mysql database.
|
||||
func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
||||
var err error
|
||||
|
||||
db, err := NewMySQLClient(conf, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
database := conf["database"]
|
||||
if database == "" {
|
||||
database = "vault"
|
||||
}
|
||||
table := conf["table"]
|
||||
if table == "" {
|
||||
table = "vault"
|
||||
}
|
||||
|
||||
err = validateDBTable(database, table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dbTable := fmt.Sprintf("`%s`.`%s`", database, table)
|
||||
|
||||
maxParStr, ok := conf["max_parallel"]
|
||||
var maxParInt int
|
||||
if ok {
|
||||
maxParInt, err = strconv.Atoi(maxParStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
|
||||
}
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
||||
}
|
||||
} else {
|
||||
maxParInt = physical.DefaultParallelOperations
|
||||
}
|
||||
|
||||
// Check schema exists
|
||||
var schemaExist bool
|
||||
schemaRows, err := db.Query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check mysql schema exist: %w", err)
|
||||
}
|
||||
defer schemaRows.Close()
|
||||
schemaExist = schemaRows.Next()
|
||||
|
||||
// Check table exists
|
||||
var tableExist bool
|
||||
tableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", table, database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check mysql table exist: %w", err)
|
||||
}
|
||||
defer tableRows.Close()
|
||||
tableExist = tableRows.Next()
|
||||
|
||||
// Create the required database if it doesn't exists.
|
||||
if !schemaExist {
|
||||
if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `" + database + "`"); err != nil {
|
||||
return nil, fmt.Errorf("failed to create mysql database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the required table if it doesn't exists.
|
||||
if !tableExist {
|
||||
create_query := "CREATE TABLE IF NOT EXISTS " + dbTable +
|
||||
" (vault_key varbinary(3072), vault_value mediumblob, PRIMARY KEY (vault_key))"
|
||||
if _, err := db.Exec(create_query); err != nil {
|
||||
return nil, fmt.Errorf("failed to create mysql table: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Default value for ha_enabled
|
||||
haEnabledStr, ok := conf["ha_enabled"]
|
||||
if !ok {
|
||||
haEnabledStr = "false"
|
||||
}
|
||||
haEnabled, err := strconv.ParseBool(haEnabledStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("value [%v] of 'ha_enabled' could not be understood", haEnabledStr)
|
||||
}
|
||||
|
||||
locktable, ok := conf["lock_table"]
|
||||
if !ok {
|
||||
locktable = table + "_lock"
|
||||
}
|
||||
|
||||
dbLockTable := "`" + database + "`.`" + locktable + "`"
|
||||
|
||||
// Only create lock table if ha_enabled is true
|
||||
if haEnabled {
|
||||
// Check table exists
|
||||
var lockTableExist bool
|
||||
lockTableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", locktable, database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check mysql table exist: %w", err)
|
||||
}
|
||||
defer lockTableRows.Close()
|
||||
lockTableExist = lockTableRows.Next()
|
||||
|
||||
// Create the required table if it doesn't exists.
|
||||
if !lockTableExist {
|
||||
create_query := "CREATE TABLE IF NOT EXISTS " + dbLockTable +
|
||||
" (node_job varbinary(512), current_leader varbinary(512), PRIMARY KEY (node_job))"
|
||||
if _, err := db.Exec(create_query); err != nil {
|
||||
return nil, fmt.Errorf("failed to create mysql table: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setup the backend.
|
||||
m := &MySQLBackend{
|
||||
dbTable: dbTable,
|
||||
dbLockTable: dbLockTable,
|
||||
client: db,
|
||||
statements: make(map[string]*sql.Stmt),
|
||||
logger: logger,
|
||||
permitPool: physical.NewPermitPool(maxParInt),
|
||||
conf: conf,
|
||||
haEnabled: haEnabled,
|
||||
}
|
||||
|
||||
// Prepare all the statements required
|
||||
statements := map[string]string{
|
||||
"put": "INSERT INTO " + dbTable +
|
||||
" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)",
|
||||
"get": "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?",
|
||||
"delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?",
|
||||
"list": "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?",
|
||||
}
|
||||
|
||||
// Only prepare ha-related statements if we need them
|
||||
if haEnabled {
|
||||
statements["get_lock"] = "SELECT current_leader FROM " + dbLockTable + " WHERE node_job = ?"
|
||||
statements["used_lock"] = "SELECT IS_USED_LOCK(?)"
|
||||
}
|
||||
|
||||
for name, query := range statements {
|
||||
if err := m.prepare(name, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// validateDBTable to prevent SQL injection attacks. This ensures that the database and table names only have valid
|
||||
// characters in them. MySQL allows for more characters that this will allow, but there isn't an easy way of
|
||||
// representing the full Unicode Basic Multilingual Plane to check against.
|
||||
// https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
|
||||
func validateDBTable(db, table string) (err error) {
|
||||
merr := &multierror.Error{}
|
||||
merr = multierror.Append(merr, wrapErr("invalid database: %w", validate(db)))
|
||||
merr = multierror.Append(merr, wrapErr("invalid table: %w", validate(table)))
|
||||
return merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func validate(name string) (err error) {
|
||||
if name == "" {
|
||||
return fmt.Errorf("missing name")
|
||||
}
|
||||
// From: https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
|
||||
// - Permitted characters in quoted identifiers include the full Unicode Basic Multilingual Plane (BMP), except U+0000:
|
||||
// ASCII: U+0001 .. U+007F
|
||||
// Extended: U+0080 .. U+FFFF
|
||||
// - ASCII NUL (U+0000) and supplementary characters (U+10000 and higher) are not permitted in quoted or unquoted identifiers.
|
||||
// - Identifiers may begin with a digit but unless quoted may not consist solely of digits.
|
||||
// - Database, table, and column names cannot end with space characters.
|
||||
//
|
||||
// We are explicitly excluding all space characters (it's easier to deal with)
|
||||
// The name will be quoted, so the all-digit requirement doesn't apply
|
||||
runes := []rune(name)
|
||||
validationErr := fmt.Errorf("invalid character found: can only include printable, non-space characters between [0x0001-0xFFFF]")
|
||||
for _, r := range runes {
|
||||
// U+0000 Explicitly disallowed
|
||||
if r == 0x0000 {
|
||||
return fmt.Errorf("invalid character: cannot include 0x0000")
|
||||
}
|
||||
// Cannot be above 0xFFFF
|
||||
if r > 0xFFFF {
|
||||
return fmt.Errorf("invalid character: cannot include any characters above 0xFFFF")
|
||||
}
|
||||
if r == '`' {
|
||||
return fmt.Errorf("invalid character: cannot include '`' character")
|
||||
}
|
||||
if r == '\'' || r == '"' {
|
||||
return fmt.Errorf("invalid character: cannot include quotes")
|
||||
}
|
||||
// We are excluding non-printable characters (not mentioned in the docs)
|
||||
if !unicode.IsPrint(r) {
|
||||
return validationErr
|
||||
}
|
||||
// We are excluding space characters (not mentioned in the docs)
|
||||
if unicode.IsSpace(r) {
|
||||
return validationErr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func wrapErr(message string, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf(message, err)
|
||||
}
|
||||
|
||||
func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) {
|
||||
var err error
|
||||
|
||||
// Get the MySQL credentials to perform read/write operations.
|
||||
username, ok := conf["username"]
|
||||
if !ok || username == "" {
|
||||
return nil, fmt.Errorf("missing username")
|
||||
}
|
||||
password, ok := conf["password"]
|
||||
if !ok || password == "" {
|
||||
return nil, fmt.Errorf("missing password")
|
||||
}
|
||||
|
||||
// Get or set MySQL server address. Defaults to localhost and default port(3306)
|
||||
address, ok := conf["address"]
|
||||
if !ok {
|
||||
address = "127.0.0.1:3306"
|
||||
}
|
||||
|
||||
maxIdleConnStr, ok := conf["max_idle_connections"]
|
||||
var maxIdleConnInt int
|
||||
if ok {
|
||||
maxIdleConnInt, err = strconv.Atoi(maxIdleConnStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err)
|
||||
}
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnInt)
|
||||
}
|
||||
}
|
||||
|
||||
maxConnLifeStr, ok := conf["max_connection_lifetime"]
|
||||
var maxConnLifeInt int
|
||||
if ok {
|
||||
maxConnLifeInt, err = strconv.Atoi(maxConnLifeStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing max_connection_lifetime parameter: %w", err)
|
||||
}
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("max_connection_lifetime set", "max_connection_lifetime", maxConnLifeInt)
|
||||
}
|
||||
}
|
||||
|
||||
maxParStr, ok := conf["max_parallel"]
|
||||
var maxParInt int
|
||||
if ok {
|
||||
maxParInt, err = strconv.Atoi(maxParStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
|
||||
}
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
||||
}
|
||||
} else {
|
||||
maxParInt = physical.DefaultParallelOperations
|
||||
}
|
||||
|
||||
dsnParams := url.Values{}
|
||||
tlsCaFile, tlsOk := conf["tls_ca_file"]
|
||||
if tlsOk {
|
||||
if err := setupMySQLTLSConfig(tlsCaFile); err != nil {
|
||||
return nil, fmt.Errorf("failed register TLS config: %w", err)
|
||||
}
|
||||
|
||||
dsnParams.Add("tls", mysqlTLSKey)
|
||||
}
|
||||
ptAllowed, ptOk := conf["plaintext_connection_allowed"]
|
||||
if !(ptOk && strings.ToLower(ptAllowed) == "true") && !tlsOk {
|
||||
logger.Warn("No TLS specified, credentials will be sent in plaintext. To mute this warning add 'plaintext_connection_allowed' with a true value to your MySQL configuration in your config file.")
|
||||
}
|
||||
|
||||
// Create MySQL handle for the database.
|
||||
dsn := username + ":" + password + "@tcp(" + address + ")/?" + dsnParams.Encode()
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to mysql: %w", err)
|
||||
}
|
||||
db.SetMaxOpenConns(maxParInt)
|
||||
if maxIdleConnInt != 0 {
|
||||
db.SetMaxIdleConns(maxIdleConnInt)
|
||||
}
|
||||
if maxConnLifeInt != 0 {
|
||||
db.SetConnMaxLifetime(time.Duration(maxConnLifeInt) * time.Second)
|
||||
}
|
||||
|
||||
return db, err
|
||||
}
|
||||
|
||||
// prepare is a helper to prepare a query for future execution
|
||||
func (m *MySQLBackend) prepare(name, query string) error {
|
||||
stmt, err := m.client.Prepare(query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare %q: %w", name, err)
|
||||
}
|
||||
m.statements[name] = stmt
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put is used to insert or update an entry.
|
||||
func (m *MySQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
|
||||
defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
_, err := m.statements["put"].Exec(entry.Key, entry.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get is used to fetch an entry.
|
||||
func (m *MySQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
|
||||
defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
var result []byte
|
||||
err := m.statements["get"].QueryRow(key).Scan(&result)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ent := &physical.Entry{
|
||||
Key: key,
|
||||
Value: result,
|
||||
}
|
||||
return ent, nil
|
||||
}
|
||||
|
||||
// Delete is used to permanently delete an entry
|
||||
func (m *MySQLBackend) Delete(ctx context.Context, key string) error {
|
||||
defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
_, err := m.statements["delete"].Exec(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List is used to list all the keys under a given
|
||||
// prefix, up to the next prefix.
|
||||
func (m *MySQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
|
||||
defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
// Add the % wildcard to the prefix to do the prefix search
|
||||
likePrefix := prefix + "%"
|
||||
rows, err := m.statements["list"].Query(likePrefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute statement: %w", err)
|
||||
}
|
||||
|
||||
var keys []string
|
||||
for rows.Next() {
|
||||
var key string
|
||||
err = rows.Scan(&key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan rows: %w", err)
|
||||
}
|
||||
|
||||
key = strings.TrimPrefix(key, prefix)
|
||||
if i := strings.Index(key, "/"); i == -1 {
|
||||
// Add objects only from the current 'folder'
|
||||
keys = append(keys, key)
|
||||
} else if i != -1 {
|
||||
// Add truncated 'folder' paths
|
||||
keys = strutil.AppendIfMissing(keys, string(key[:i+1]))
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(keys)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// LockWith is used for mutual exclusion based on the given key.
|
||||
func (m *MySQLBackend) LockWith(key, value string) (physical.Lock, error) {
|
||||
l := &MySQLHALock{
|
||||
in: m,
|
||||
key: key,
|
||||
value: value,
|
||||
logger: m.logger,
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (m *MySQLBackend) HAEnabled() bool {
|
||||
return m.haEnabled
|
||||
}
|
||||
|
||||
// MySQLHALock is a MySQL Lock implementation for the HABackend
|
||||
type MySQLHALock struct {
|
||||
in *MySQLBackend
|
||||
key string
|
||||
value string
|
||||
logger log.Logger
|
||||
|
||||
held bool
|
||||
localLock sync.Mutex
|
||||
leaderCh chan struct{}
|
||||
stopCh <-chan struct{}
|
||||
lock *MySQLLock
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
|
||||
i.localLock.Lock()
|
||||
defer i.localLock.Unlock()
|
||||
if i.held {
|
||||
return nil, fmt.Errorf("lock already held")
|
||||
}
|
||||
|
||||
// Attempt an async acquisition
|
||||
didLock := make(chan struct{})
|
||||
failLock := make(chan error, 1)
|
||||
releaseCh := make(chan bool, 1)
|
||||
go i.attemptLock(i.key, i.value, didLock, failLock, releaseCh)
|
||||
|
||||
// Wait for lock acquisition, failure, or shutdown
|
||||
select {
|
||||
case <-didLock:
|
||||
releaseCh <- false
|
||||
case err := <-failLock:
|
||||
return nil, err
|
||||
case <-stopCh:
|
||||
releaseCh <- true
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Create the leader channel
|
||||
i.held = true
|
||||
i.leaderCh = make(chan struct{})
|
||||
|
||||
go i.monitorLock(i.leaderCh)
|
||||
|
||||
i.stopCh = stopCh
|
||||
|
||||
return i.leaderCh, nil
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) attemptLock(key, value string, didLock chan struct{}, failLock chan error, releaseCh chan bool) {
|
||||
lock, err := NewMySQLLock(i.in, i.logger, key, value)
|
||||
if err != nil {
|
||||
failLock <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Set node value
|
||||
i.lock = lock
|
||||
|
||||
err = lock.Lock()
|
||||
if err != nil {
|
||||
failLock <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Signal that lock is held
|
||||
close(didLock)
|
||||
|
||||
// Handle an early abort
|
||||
release := <-releaseCh
|
||||
if release {
|
||||
lock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) monitorLock(leaderCh chan struct{}) {
|
||||
for {
|
||||
// The only way to lose this lock is if someone is
|
||||
// logging into the DB and altering system tables or you lose a connection in
|
||||
// which case you will lose the lock anyway.
|
||||
err := i.hasLock(i.key)
|
||||
if err != nil {
|
||||
// Somehow we lost the lock.... likely because the connection holding
|
||||
// the lock was closed or someone was playing around with the locks in the DB.
|
||||
close(leaderCh)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) Unlock() error {
|
||||
i.localLock.Lock()
|
||||
defer i.localLock.Unlock()
|
||||
if !i.held {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := i.lock.Unlock()
|
||||
|
||||
if err == nil {
|
||||
i.held = false
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// hasLock will check if a lock is held by checking the current lock id against our known ID.
|
||||
func (i *MySQLHALock) hasLock(key string) error {
|
||||
var result sql.NullInt64
|
||||
err := i.in.statements["used_lock"].QueryRow(key).Scan(&result)
|
||||
if err == sql.ErrNoRows || !result.Valid {
|
||||
// This is not an error to us since it just means the lock isn't held
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// IS_USED_LOCK will return the ID of the connection that created the lock.
|
||||
if result.Int64 != GlobalLockID {
|
||||
return ErrLockHeld
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) GetLeader() (string, error) {
|
||||
defer metrics.MeasureSince([]string{"mysql", "lock_get"}, time.Now())
|
||||
var result string
|
||||
err := i.in.statements["get_lock"].QueryRow("leader").Scan(&result)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (i *MySQLHALock) Value() (bool, string, error) {
|
||||
leaderkey, err := i.GetLeader()
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
return true, leaderkey, err
|
||||
}
|
||||
|
||||
// MySQLLock provides an easy way to grab and release mysql
|
||||
// locks using the built in GET_LOCK function. Note that these
|
||||
// locks are released when you lose connection to the server.
|
||||
type MySQLLock struct {
|
||||
parentConn *MySQLBackend
|
||||
in *sql.DB
|
||||
logger log.Logger
|
||||
statements map[string]*sql.Stmt
|
||||
key string
|
||||
value string
|
||||
}
|
||||
|
||||
// Errors specific to trying to grab a lock in MySQL
|
||||
var (
|
||||
// This is the GlobalLockID for checking if the lock we got is still the current lock
|
||||
GlobalLockID int64
|
||||
// ErrLockHeld is returned when another vault instance already has a lock held for the given key.
|
||||
ErrLockHeld = errors.New("mysql: lock already held")
|
||||
// ErrUnlockFailed
|
||||
ErrUnlockFailed = errors.New("mysql: unable to release lock, already released or not held by this session")
|
||||
// You were unable to update that you are the new leader in the DB
|
||||
ErrClaimFailed = errors.New("mysql: unable to update DB with new leader information")
|
||||
// Error to throw if between getting the lock and checking the ID of it we lost it.
|
||||
ErrSettingGlobalID = errors.New("mysql: getting global lock id failed")
|
||||
)
|
||||
|
||||
// NewMySQLLock helper function
|
||||
func NewMySQLLock(in *MySQLBackend, l log.Logger, key, value string) (*MySQLLock, error) {
|
||||
// Create a new MySQL connection so we can close this and have no effect on
|
||||
// the rest of the MySQL backend and any cleanup that might need to be done.
|
||||
conn, _ := NewMySQLClient(in.conf, in.logger)
|
||||
|
||||
m := &MySQLLock{
|
||||
parentConn: in,
|
||||
in: conn,
|
||||
logger: l,
|
||||
statements: make(map[string]*sql.Stmt),
|
||||
key: key,
|
||||
value: value,
|
||||
}
|
||||
|
||||
statements := map[string]string{
|
||||
"put": "INSERT INTO " + in.dbLockTable +
|
||||
" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE current_leader=VALUES(current_leader)",
|
||||
}
|
||||
|
||||
for name, query := range statements {
|
||||
if err := m.prepare(name, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// prepare is a helper to prepare a query for future execution
|
||||
func (m *MySQLLock) prepare(name, query string) error {
|
||||
stmt, err := m.in.Prepare(query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare %q: %w", name, err)
|
||||
}
|
||||
m.statements[name] = stmt
|
||||
return nil
|
||||
}
|
||||
|
||||
// update the current cluster leader in the DB. This is used so
|
||||
// we can tell the servers in standby who the active leader is.
|
||||
func (i *MySQLLock) becomeLeader() error {
|
||||
_, err := i.statements["put"].Exec("leader", i.value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lock will try to get a lock for an indefinite amount of time
|
||||
// based on the given key that has been requested.
|
||||
func (i *MySQLLock) Lock() error {
|
||||
defer metrics.MeasureSince([]string{"mysql", "get_lock"}, time.Now())
|
||||
|
||||
// Lock timeout math.MaxInt32 instead of -1 solves compatibility issues with
|
||||
// different MySQL flavours i.e. MariaDB
|
||||
rows, err := i.in.Query("SELECT GET_LOCK(?, ?), IS_USED_LOCK(?)", i.key, math.MaxInt32, i.key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
rows.Next()
|
||||
var lock sql.NullInt64
|
||||
var connectionID sql.NullInt64
|
||||
rows.Scan(&lock, &connectionID)
|
||||
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// 1 is returned from GET_LOCK if it was able to get the lock
|
||||
// 0 if it failed and NULL if some strange error happened.
|
||||
// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_get-lock
|
||||
if !lock.Valid || lock.Int64 != 1 {
|
||||
return ErrLockHeld
|
||||
}
|
||||
|
||||
// Since we have the lock alert the rest of the cluster
|
||||
// that we are now the active leader.
|
||||
err = i.becomeLeader()
|
||||
if err != nil {
|
||||
return ErrLockHeld
|
||||
}
|
||||
|
||||
// This will return the connection ID of NULL if an error happens
|
||||
// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_is-used-lock
|
||||
if !connectionID.Valid {
|
||||
return ErrSettingGlobalID
|
||||
}
|
||||
|
||||
GlobalLockID = connectionID.Int64
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlock just closes the connection. This is because closing the MySQL connection
|
||||
// is a 100% reliable way to close the lock. If you just release the lock you must
|
||||
// do it from the same mysql connection_id that you originally created it from. This
|
||||
// is a huge hastle and I actually couldn't find a clean way to do this although one
|
||||
// likely does exist. Closing the connection however ensures we don't ever get into a
|
||||
// state where we try to release the lock and it hangs it is also much less code.
|
||||
func (i *MySQLLock) Unlock() error {
|
||||
err := i.in.Close()
|
||||
if err != nil {
|
||||
return ErrUnlockFailed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Establish a TLS connection with a given CA certificate
|
||||
// Register a tsl.Config associated with the same key as the dns param from sql.Open
|
||||
// foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default
|
||||
func setupMySQLTLSConfig(tlsCaFile string) error {
|
||||
rootCertPool := x509.NewCertPool()
|
||||
|
||||
pem, err := ioutil.ReadFile(tlsCaFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mysql.RegisterTLSConfig(mysqlTLSKey, &tls.Config{
|
||||
RootCAs: rootCertPool,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,346 +0,0 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
mysql "github.com/go-sql-driver/mysql"
|
||||
|
||||
mysqlhelper "github.com/hashicorp/vault/helper/testhelpers/mysql"
|
||||
)
|
||||
|
||||
func TestMySQLPlaintextCatch(t *testing.T) {
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
database := os.Getenv("MYSQL_DB")
|
||||
if database == "" {
|
||||
database = "test"
|
||||
}
|
||||
|
||||
table := os.Getenv("MYSQL_TABLE")
|
||||
if table == "" {
|
||||
table = "test"
|
||||
}
|
||||
|
||||
username := os.Getenv("MYSQL_USERNAME")
|
||||
password := os.Getenv("MYSQL_PASSWORD")
|
||||
|
||||
// Run vault tests
|
||||
var buf bytes.Buffer
|
||||
log.DefaultOutput = &buf
|
||||
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
NewMySQLBackend(map[string]string{
|
||||
"address": address,
|
||||
"database": database,
|
||||
"table": table,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"plaintext_connection_allowed": "false",
|
||||
}, logger)
|
||||
|
||||
str := buf.String()
|
||||
dataIdx := strings.IndexByte(str, ' ')
|
||||
rest := str[dataIdx+1:]
|
||||
|
||||
if !strings.Contains(rest, "credentials will be sent in plaintext") {
|
||||
t.Fatalf("No warning of plaintext credentials occurred")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQLBackend(t *testing.T) {
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
database := os.Getenv("MYSQL_DB")
|
||||
if database == "" {
|
||||
database = "test"
|
||||
}
|
||||
|
||||
table := os.Getenv("MYSQL_TABLE")
|
||||
if table == "" {
|
||||
table = "test"
|
||||
}
|
||||
|
||||
username := os.Getenv("MYSQL_USERNAME")
|
||||
password := os.Getenv("MYSQL_PASSWORD")
|
||||
|
||||
// Run vault tests
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
b, err := NewMySQLBackend(map[string]string{
|
||||
"address": address,
|
||||
"database": database,
|
||||
"table": table,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"plaintext_connection_allowed": "true",
|
||||
"max_connection_lifetime": "1",
|
||||
}, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
mysql := b.(*MySQLBackend)
|
||||
_, err := mysql.client.Exec("DROP TABLE IF EXISTS " + mysql.dbTable + " ," + mysql.dbLockTable)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to drop table: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
physical.ExerciseBackend(t, b)
|
||||
physical.ExerciseBackend_ListPrefix(t, b)
|
||||
}
|
||||
|
||||
func TestMySQLHABackend(t *testing.T) {
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
database := os.Getenv("MYSQL_DB")
|
||||
if database == "" {
|
||||
database = "test"
|
||||
}
|
||||
|
||||
table := os.Getenv("MYSQL_TABLE")
|
||||
if table == "" {
|
||||
table = "test"
|
||||
}
|
||||
|
||||
username := os.Getenv("MYSQL_USERNAME")
|
||||
password := os.Getenv("MYSQL_PASSWORD")
|
||||
|
||||
// Run vault tests
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
config := map[string]string{
|
||||
"address": address,
|
||||
"database": database,
|
||||
"table": table,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"ha_enabled": "true",
|
||||
"plaintext_connection_allowed": "true",
|
||||
}
|
||||
|
||||
b, err := NewMySQLBackend(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
mysql := b.(*MySQLBackend)
|
||||
_, err := mysql.client.Exec("DROP TABLE IF EXISTS " + mysql.dbTable + " ," + mysql.dbLockTable)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to drop table: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
b2, err := NewMySQLBackend(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
physical.ExerciseHABackend(t, b.(physical.HABackend), b2.(physical.HABackend))
|
||||
}
|
||||
|
||||
// TestMySQLHABackend_LockFailPanic is a regression test for the panic shown in
|
||||
// https://github.com/hashicorp/vault/issues/8203 and patched in
|
||||
// https://github.com/hashicorp/vault/pull/8229
|
||||
func TestMySQLHABackend_LockFailPanic(t *testing.T) {
|
||||
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
|
||||
|
||||
cfg, err := mysql.ParseDSN(connURL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := mysqlhelper.TestCredsExist(t, connURL, cfg.User, cfg.Passwd); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
table := "test"
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
config := map[string]string{
|
||||
"address": cfg.Addr,
|
||||
"database": cfg.DBName,
|
||||
"table": table,
|
||||
"username": cfg.User,
|
||||
"password": cfg.Passwd,
|
||||
"ha_enabled": "true",
|
||||
"plaintext_connection_allowed": "true",
|
||||
}
|
||||
|
||||
b, err := NewMySQLBackend(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
b2, err := NewMySQLBackend(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
b1ha := b.(physical.HABackend)
|
||||
b2ha := b2.(physical.HABackend)
|
||||
|
||||
// Copied from ExerciseHABackend - ensuring things are normal at this point
|
||||
// Get the lock
|
||||
lock, err := b1ha.LockWith("foo", "bar")
|
||||
if err != nil {
|
||||
t.Fatalf("initial lock: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to lock
|
||||
leaderCh, err := lock.Lock(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("lock attempt 1: %v", err)
|
||||
}
|
||||
if leaderCh == nil {
|
||||
t.Fatalf("missing leaderCh")
|
||||
}
|
||||
|
||||
// Check the value
|
||||
held, val, err := lock.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !held {
|
||||
t.Errorf("should be held")
|
||||
}
|
||||
if val != "bar" {
|
||||
t.Errorf("expected value bar: %v", err)
|
||||
}
|
||||
|
||||
// Second acquisition should fail
|
||||
lock2, err := b2ha.LockWith("foo", "baz")
|
||||
if err != nil {
|
||||
t.Fatalf("lock 2: %v", err)
|
||||
}
|
||||
|
||||
stopCh := make(chan struct{})
|
||||
time.AfterFunc(10*time.Second, func() {
|
||||
close(stopCh)
|
||||
})
|
||||
|
||||
// Attempt to lock - can't lock because lock1 is held - this is normal
|
||||
leaderCh2, err := lock2.Lock(stopCh)
|
||||
if err != nil {
|
||||
t.Fatalf("stop lock 2: %v", err)
|
||||
}
|
||||
if leaderCh2 != nil {
|
||||
t.Errorf("should not have gotten leaderCh: %v", leaderCh2)
|
||||
}
|
||||
// end normal
|
||||
|
||||
// Clean up the database. When Lock() is called, a new connection is created
|
||||
// using the configuration. If that connection cannot be created, there was a
|
||||
// panic due to not returning with the connection error. Here we intentionally
|
||||
// break the config for b2, so a new connection can't be made, which would
|
||||
// trigger the panic shown in https://github.com/hashicorp/vault/issues/8203
|
||||
cleanup()
|
||||
|
||||
stopCh2 := make(chan struct{})
|
||||
time.AfterFunc(10*time.Second, func() {
|
||||
close(stopCh2)
|
||||
})
|
||||
leaderCh2, err = lock2.Lock(stopCh2)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got none, leaderCh2=%v", leaderCh2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateDBTable(t *testing.T) {
|
||||
type testCase struct {
|
||||
database string
|
||||
table string
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"empty database & table": {"", "", true},
|
||||
"empty database": {"", "a", true},
|
||||
"empty table": {"a", "", true},
|
||||
"ascii database": {"abcde", "a", false},
|
||||
"ascii table": {"a", "abcde", false},
|
||||
"ascii database & table": {"abcde", "abcde", false},
|
||||
"only whitespace db": {" ", "a", true},
|
||||
"only whitespace table": {"a", " ", true},
|
||||
"whitespace prefix db": {" bcde", "a", true},
|
||||
"whitespace middle db": {"ab de", "a", true},
|
||||
"whitespace suffix db": {"abcd ", "a", true},
|
||||
"whitespace prefix table": {"a", " bcde", true},
|
||||
"whitespace middle table": {"a", "ab de", true},
|
||||
"whitespace suffix table": {"a", "abcd ", true},
|
||||
"backtick prefix db": {"`bcde", "a", true},
|
||||
"backtick middle db": {"ab`de", "a", true},
|
||||
"backtick suffix db": {"abcd`", "a", true},
|
||||
"backtick prefix table": {"a", "`bcde", true},
|
||||
"backtick middle table": {"a", "ab`de", true},
|
||||
"backtick suffix table": {"a", "abcd`", true},
|
||||
"single quote prefix db": {"'bcde", "a", true},
|
||||
"single quote middle db": {"ab'de", "a", true},
|
||||
"single quote suffix db": {"abcd'", "a", true},
|
||||
"single quote prefix table": {"a", "'bcde", true},
|
||||
"single quote middle table": {"a", "ab'de", true},
|
||||
"single quote suffix table": {"a", "abcd'", true},
|
||||
"double quote prefix db": {`"bcde`, "a", true},
|
||||
"double quote middle db": {`ab"de`, "a", true},
|
||||
"double quote suffix db": {`abcd"`, "a", true},
|
||||
"double quote prefix table": {"a", `"bcde`, true},
|
||||
"double quote middle table": {"a", `ab"de`, true},
|
||||
"double quote suffix table": {"a", `abcd"`, true},
|
||||
"0x0000 prefix db": {str(0x0000, 'b', 'c'), "a", true},
|
||||
"0x0000 middle db": {str('a', 0x0000, 'c'), "a", true},
|
||||
"0x0000 suffix db": {str('a', 'b', 0x0000), "a", true},
|
||||
"0x0000 prefix table": {"a", str(0x0000, 'b', 'c'), true},
|
||||
"0x0000 middle table": {"a", str('a', 0x0000, 'c'), true},
|
||||
"0x0000 suffix table": {"a", str('a', 'b', 0x0000), true},
|
||||
"unicode > 0xFFFF prefix db": {str(0x10000, 'b', 'c'), "a", true},
|
||||
"unicode > 0xFFFF middle db": {str('a', 0x10000, 'c'), "a", true},
|
||||
"unicode > 0xFFFF suffix db": {str('a', 'b', 0x10000), "a", true},
|
||||
"unicode > 0xFFFF prefix table": {"a", str(0x10000, 'b', 'c'), true},
|
||||
"unicode > 0xFFFF middle table": {"a", str('a', 0x10000, 'c'), true},
|
||||
"unicode > 0xFFFF suffix table": {"a", str('a', 'b', 0x10000), true},
|
||||
"non-printable prefix db": {str(0x0001, 'b', 'c'), "a", true},
|
||||
"non-printable middle db": {str('a', 0x0001, 'c'), "a", true},
|
||||
"non-printable suffix db": {str('a', 'b', 0x0001), "a", true},
|
||||
"non-printable prefix table": {"a", str(0x0001, 'b', 'c'), true},
|
||||
"non-printable middle table": {"a", str('a', 0x0001, 'c'), true},
|
||||
"non-printable suffix table": {"a", str('a', 'b', 0x0001), true},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := validateDBTable(test.database, test.table)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func str(r ...rune) string {
|
||||
return string(r)
|
||||
}
|
|
@ -1,474 +0,0 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
// The lock TTL matches the default that Consul API uses, 15 seconds.
|
||||
// Used as part of SQL commands to set/extend lock expiry time relative to
|
||||
// database clock.
|
||||
PostgreSQLLockTTLSeconds = 15
|
||||
|
||||
// The amount of time to wait between the lock renewals
|
||||
PostgreSQLLockRenewInterval = 5 * time.Second
|
||||
|
||||
// PostgreSQLLockRetryInterval is the amount of time to wait
|
||||
// if a lock fails before trying again.
|
||||
PostgreSQLLockRetryInterval = time.Second
|
||||
)
|
||||
|
||||
// Verify PostgreSQLBackend satisfies the correct interfaces
|
||||
var _ physical.Backend = (*PostgreSQLBackend)(nil)
|
||||
|
||||
// HA backend was implemented based on the DynamoDB backend pattern
|
||||
// With distinction using central postgres clock, hereby avoiding
|
||||
// possible issues with multiple clocks
|
||||
var (
|
||||
_ physical.HABackend = (*PostgreSQLBackend)(nil)
|
||||
_ physical.Lock = (*PostgreSQLLock)(nil)
|
||||
)
|
||||
|
||||
// PostgreSQL Backend is a physical backend that stores data
|
||||
// within a PostgreSQL database.
|
||||
type PostgreSQLBackend struct {
|
||||
table string
|
||||
client *sql.DB
|
||||
put_query string
|
||||
get_query string
|
||||
delete_query string
|
||||
list_query string
|
||||
|
||||
ha_table string
|
||||
haGetLockValueQuery string
|
||||
haUpsertLockIdentityExec string
|
||||
haDeleteLockExec string
|
||||
|
||||
haEnabled bool
|
||||
logger log.Logger
|
||||
permitPool *physical.PermitPool
|
||||
}
|
||||
|
||||
// PostgreSQLLock implements a lock using an PostgreSQL client.
|
||||
type PostgreSQLLock struct {
|
||||
backend *PostgreSQLBackend
|
||||
value, key string
|
||||
identity string
|
||||
lock sync.Mutex
|
||||
|
||||
renewTicker *time.Ticker
|
||||
|
||||
// ttlSeconds is how long a lock is valid for
|
||||
ttlSeconds int
|
||||
|
||||
// renewInterval is how much time to wait between lock renewals. must be << ttl
|
||||
renewInterval time.Duration
|
||||
|
||||
// retryInterval is how much time to wait between attempts to grab the lock
|
||||
retryInterval time.Duration
|
||||
}
|
||||
|
||||
// NewPostgreSQLBackend constructs a PostgreSQL backend using the given
|
||||
// API client, server address, credentials, and database.
|
||||
func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
||||
// Get the PostgreSQL credentials to perform read/write operations.
|
||||
connURL := connectionURL(conf)
|
||||
if connURL == "" {
|
||||
return nil, fmt.Errorf("missing connection_url")
|
||||
}
|
||||
|
||||
unquoted_table, ok := conf["table"]
|
||||
if !ok {
|
||||
unquoted_table = "vault_kv_store"
|
||||
}
|
||||
quoted_table := dbutil.QuoteIdentifier(unquoted_table)
|
||||
|
||||
maxParStr, ok := conf["max_parallel"]
|
||||
var maxParInt int
|
||||
var err error
|
||||
if ok {
|
||||
maxParInt, err = strconv.Atoi(maxParStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
|
||||
}
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
||||
}
|
||||
} else {
|
||||
maxParInt = physical.DefaultParallelOperations
|
||||
}
|
||||
|
||||
maxIdleConnsStr, maxIdleConnsIsSet := conf["max_idle_connections"]
|
||||
var maxIdleConns int
|
||||
if maxIdleConnsIsSet {
|
||||
maxIdleConns, err = strconv.Atoi(maxIdleConnsStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err)
|
||||
}
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnsStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Create PostgreSQL handle for the database.
|
||||
db, err := sql.Open("pgx", connURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to postgres: %w", err)
|
||||
}
|
||||
db.SetMaxOpenConns(maxParInt)
|
||||
|
||||
if maxIdleConnsIsSet {
|
||||
db.SetMaxIdleConns(maxIdleConns)
|
||||
}
|
||||
|
||||
// Determine if we should use a function to work around lack of upsert (versions < 9.5)
|
||||
var upsertAvailable bool
|
||||
upsertAvailableQuery := "SELECT current_setting('server_version_num')::int >= 90500"
|
||||
if err := db.QueryRow(upsertAvailableQuery).Scan(&upsertAvailable); err != nil {
|
||||
return nil, fmt.Errorf("failed to check for native upsert: %w", err)
|
||||
}
|
||||
|
||||
if !upsertAvailable && conf["ha_enabled"] == "true" {
|
||||
return nil, fmt.Errorf("ha_enabled=true in config but PG version doesn't support HA, must be at least 9.5")
|
||||
}
|
||||
|
||||
// Setup our put strategy based on the presence or absence of a native
|
||||
// upsert.
|
||||
var put_query string
|
||||
if !upsertAvailable {
|
||||
put_query = "SELECT vault_kv_put($1, $2, $3, $4)"
|
||||
} else {
|
||||
put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
|
||||
" ON CONFLICT (path, key) DO " +
|
||||
" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
|
||||
}
|
||||
|
||||
unquoted_ha_table, ok := conf["ha_table"]
|
||||
if !ok {
|
||||
unquoted_ha_table = "vault_ha_locks"
|
||||
}
|
||||
quoted_ha_table := dbutil.QuoteIdentifier(unquoted_ha_table)
|
||||
|
||||
// Setup the backend.
|
||||
m := &PostgreSQLBackend{
|
||||
table: quoted_table,
|
||||
client: db,
|
||||
put_query: put_query,
|
||||
get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
|
||||
delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
|
||||
list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
|
||||
" UNION ALL SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + quoted_table +
|
||||
" WHERE parent_path LIKE $1 || '%'",
|
||||
haGetLockValueQuery:
|
||||
// only read non expired data
|
||||
" SELECT ha_value FROM " + quoted_ha_table + " WHERE NOW() <= valid_until AND ha_key = $1 ",
|
||||
haUpsertLockIdentityExec:
|
||||
// $1=identity $2=ha_key $3=ha_value $4=TTL in seconds
|
||||
// update either steal expired lock OR update expiry for lock owned by me
|
||||
" INSERT INTO " + quoted_ha_table + " as t (ha_identity, ha_key, ha_value, valid_until) VALUES ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds' ) " +
|
||||
" ON CONFLICT (ha_key) DO " +
|
||||
" UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds') " +
|
||||
" WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " +
|
||||
" (t.ha_identity = $1 AND t.ha_key = $2) ",
|
||||
haDeleteLockExec:
|
||||
// $1=ha_identity $2=ha_key
|
||||
" DELETE FROM " + quoted_ha_table + " WHERE ha_identity=$1 AND ha_key=$2 ",
|
||||
logger: logger,
|
||||
permitPool: physical.NewPermitPool(maxParInt),
|
||||
haEnabled: conf["ha_enabled"] == "true",
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// connectionURL first check the environment variables for a connection URL. If
|
||||
// no connection URL exists in the environment variable, the Vault config file is
|
||||
// checked. If neither the environment variables or the config file set the connection
|
||||
// URL for the Postgres backend, because it is a required field, an error is returned.
|
||||
func connectionURL(conf map[string]string) string {
|
||||
connURL := conf["connection_url"]
|
||||
if envURL := os.Getenv("VAULT_PG_CONNECTION_URL"); envURL != "" {
|
||||
connURL = envURL
|
||||
}
|
||||
|
||||
return connURL
|
||||
}
|
||||
|
||||
// splitKey is a helper to split a full path key into individual
|
||||
// parts: parentPath, path, key
|
||||
func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
|
||||
var parentPath string
|
||||
var path string
|
||||
|
||||
pieces := strings.Split(fullPath, "/")
|
||||
depth := len(pieces)
|
||||
key := pieces[depth-1]
|
||||
|
||||
if depth == 1 {
|
||||
parentPath = ""
|
||||
path = "/"
|
||||
} else if depth == 2 {
|
||||
parentPath = "/"
|
||||
path = "/" + pieces[0] + "/"
|
||||
} else {
|
||||
parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/"
|
||||
path = "/" + strings.Join(pieces[:depth-1], "/") + "/"
|
||||
}
|
||||
|
||||
return parentPath, path, key
|
||||
}
|
||||
|
||||
// Put is used to insert or update an entry.
|
||||
func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
|
||||
defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
parentPath, path, key := m.splitKey(entry.Key)
|
||||
|
||||
_, err := m.client.ExecContext(ctx, m.put_query, parentPath, path, key, entry.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get is used to fetch and entry.
|
||||
func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) {
|
||||
defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
_, path, key := m.splitKey(fullPath)
|
||||
|
||||
var result []byte
|
||||
err := m.client.QueryRowContext(ctx, m.get_query, path, key).Scan(&result)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ent := &physical.Entry{
|
||||
Key: fullPath,
|
||||
Value: result,
|
||||
}
|
||||
return ent, nil
|
||||
}
|
||||
|
||||
// Delete is used to permanently delete an entry
|
||||
func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error {
|
||||
defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
_, path, key := m.splitKey(fullPath)
|
||||
|
||||
_, err := m.client.ExecContext(ctx, m.delete_query, path, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List is used to list all the keys under a given
|
||||
// prefix, up to the next prefix.
|
||||
func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
|
||||
defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())
|
||||
|
||||
m.permitPool.Acquire()
|
||||
defer m.permitPool.Release()
|
||||
|
||||
rows, err := m.client.QueryContext(ctx, m.list_query, "/"+prefix)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var keys []string
|
||||
for rows.Next() {
|
||||
var key string
|
||||
err = rows.Scan(&key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan rows: %w", err)
|
||||
}
|
||||
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// LockWith is used for mutual exclusion based on the given key.
|
||||
func (p *PostgreSQLBackend) LockWith(key, value string) (physical.Lock, error) {
|
||||
identity, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PostgreSQLLock{
|
||||
backend: p,
|
||||
key: key,
|
||||
value: value,
|
||||
identity: identity,
|
||||
ttlSeconds: PostgreSQLLockTTLSeconds,
|
||||
renewInterval: PostgreSQLLockRenewInterval,
|
||||
retryInterval: PostgreSQLLockRetryInterval,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQLBackend) HAEnabled() bool {
|
||||
return p.haEnabled
|
||||
}
|
||||
|
||||
// Lock tries to acquire the lock by repeatedly trying to create a record in the
|
||||
// PostgreSQL table. It will block until either the stop channel is closed or
|
||||
// the lock could be acquired successfully. The returned channel will be closed
|
||||
// once the lock in the PostgreSQL table cannot be renewed, either due to an
|
||||
// error speaking to PostgreSQL or because someone else has taken it.
|
||||
func (l *PostgreSQLLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
var (
|
||||
success = make(chan struct{})
|
||||
errors = make(chan error)
|
||||
leader = make(chan struct{})
|
||||
)
|
||||
// try to acquire the lock asynchronously
|
||||
go l.tryToLock(stopCh, success, errors)
|
||||
|
||||
select {
|
||||
case <-success:
|
||||
// after acquiring it successfully, we must renew the lock periodically
|
||||
l.renewTicker = time.NewTicker(l.renewInterval)
|
||||
go l.periodicallyRenewLock(leader)
|
||||
case err := <-errors:
|
||||
return nil, err
|
||||
case <-stopCh:
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return leader, nil
|
||||
}
|
||||
|
||||
// Unlock releases the lock by deleting the lock record from the
|
||||
// PostgreSQL table.
|
||||
func (l *PostgreSQLLock) Unlock() error {
|
||||
pg := l.backend
|
||||
pg.permitPool.Acquire()
|
||||
defer pg.permitPool.Release()
|
||||
|
||||
if l.renewTicker != nil {
|
||||
l.renewTicker.Stop()
|
||||
}
|
||||
|
||||
// Delete lock owned by me
|
||||
_, err := pg.client.Exec(pg.haDeleteLockExec, l.identity, l.key)
|
||||
return err
|
||||
}
|
||||
|
||||
// Value checks whether or not the lock is held by any instance of PostgreSQLLock,
|
||||
// including this one, and returns the current value.
|
||||
func (l *PostgreSQLLock) Value() (bool, string, error) {
|
||||
pg := l.backend
|
||||
pg.permitPool.Acquire()
|
||||
defer pg.permitPool.Release()
|
||||
var result string
|
||||
err := pg.client.QueryRow(pg.haGetLockValueQuery, l.key).Scan(&result)
|
||||
|
||||
switch err {
|
||||
case nil:
|
||||
return true, result, nil
|
||||
case sql.ErrNoRows:
|
||||
return false, "", nil
|
||||
default:
|
||||
return false, "", err
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// tryToLock tries to create a new item in PostgreSQL every `retryInterval`.
|
||||
// As long as the item cannot be created (because it already exists), it will
|
||||
// be retried. If the operation fails due to an error, it is sent to the errors
|
||||
// channel. When the lock could be acquired successfully, the success channel
|
||||
// is closed.
|
||||
func (l *PostgreSQLLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) {
|
||||
ticker := time.NewTicker(l.retryInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-ticker.C:
|
||||
gotlock, err := l.writeItem()
|
||||
switch {
|
||||
case err != nil:
|
||||
errors <- err
|
||||
return
|
||||
case gotlock:
|
||||
close(success)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *PostgreSQLLock) periodicallyRenewLock(done chan struct{}) {
|
||||
for range l.renewTicker.C {
|
||||
gotlock, err := l.writeItem()
|
||||
if err != nil || !gotlock {
|
||||
close(done)
|
||||
l.renewTicker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Attempts to put/update the PostgreSQL item using condition expressions to
|
||||
// evaluate the TTL. Returns true if the lock was obtained, false if not.
|
||||
// If false error may be nil or non-nil: nil indicates simply that someone
|
||||
// else has the lock, whereas non-nil means that something unexpected happened.
|
||||
func (l *PostgreSQLLock) writeItem() (bool, error) {
|
||||
pg := l.backend
|
||||
pg.permitPool.Acquire()
|
||||
defer pg.permitPool.Release()
|
||||
|
||||
// Try steal lock or update expiry on my lock
|
||||
|
||||
sqlResult, err := pg.client.Exec(pg.haUpsertLockIdentityExec, l.identity, l.key, l.value, l.ttlSeconds)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if sqlResult == nil {
|
||||
return false, fmt.Errorf("empty SQL response received")
|
||||
}
|
||||
|
||||
ar, err := sqlResult.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return ar == 1, nil
|
||||
}
|
|
@ -1,426 +0,0 @@
|
|||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
)
|
||||
|
||||
func TestPostgreSQLBackend(t *testing.T) {
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
// Use docker as pg backend if no url is provided via environment variables
|
||||
connURL := os.Getenv("PGURL")
|
||||
if connURL == "" {
|
||||
cleanup, u := postgresql.PrepareTestContainer(t, "11.1")
|
||||
defer cleanup()
|
||||
connURL = u
|
||||
}
|
||||
|
||||
table := os.Getenv("PGTABLE")
|
||||
if table == "" {
|
||||
table = "vault_kv_store"
|
||||
}
|
||||
|
||||
hae := os.Getenv("PGHAENABLED")
|
||||
if hae == "" {
|
||||
hae = "true"
|
||||
}
|
||||
|
||||
// Run vault tests
|
||||
logger.Info(fmt.Sprintf("Connection URL: %v", connURL))
|
||||
|
||||
b1, err := NewPostgreSQLBackend(map[string]string{
|
||||
"connection_url": connURL,
|
||||
"table": table,
|
||||
"ha_enabled": hae,
|
||||
}, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
|
||||
b2, err := NewPostgreSQLBackend(map[string]string{
|
||||
"connection_url": connURL,
|
||||
"table": table,
|
||||
"ha_enabled": hae,
|
||||
}, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
pg := b1.(*PostgreSQLBackend)
|
||||
|
||||
// Read postgres version to test basic connects works
|
||||
var pgversion string
|
||||
if err = pg.client.QueryRow("SELECT current_setting('server_version_num')").Scan(&pgversion); err != nil {
|
||||
t.Fatalf("Failed to check for Postgres version: %v", err)
|
||||
}
|
||||
logger.Info(fmt.Sprintf("Postgres Version: %v", pgversion))
|
||||
|
||||
setupDatabaseObjects(t, logger, pg)
|
||||
|
||||
defer func() {
|
||||
pg := b1.(*PostgreSQLBackend)
|
||||
_, err := pg.client.Exec(fmt.Sprintf(" TRUNCATE TABLE %v ", pg.table))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to truncate table: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("Running basic backend tests")
|
||||
physical.ExerciseBackend(t, b1)
|
||||
logger.Info("Running list prefix backend tests")
|
||||
physical.ExerciseBackend_ListPrefix(t, b1)
|
||||
|
||||
ha1, ok := b1.(physical.HABackend)
|
||||
if !ok {
|
||||
t.Fatalf("PostgreSQLDB does not implement HABackend")
|
||||
}
|
||||
|
||||
ha2, ok := b2.(physical.HABackend)
|
||||
if !ok {
|
||||
t.Fatalf("PostgreSQLDB does not implement HABackend")
|
||||
}
|
||||
|
||||
if ha1.HAEnabled() && ha2.HAEnabled() {
|
||||
logger.Info("Running ha backend tests")
|
||||
physical.ExerciseHABackend(t, ha1, ha2)
|
||||
testPostgresSQLLockTTL(t, ha1)
|
||||
testPostgresSQLLockRenewal(t, ha1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQLBackendMaxIdleConnectionsParameter(t *testing.T) {
|
||||
_, err := NewPostgreSQLBackend(map[string]string{
|
||||
"connection_url": "some connection url",
|
||||
"max_idle_connections": "bad param",
|
||||
}, logging.NewVaultLogger(log.Debug))
|
||||
if err == nil {
|
||||
t.Error("Expected invalid max_idle_connections param to return error")
|
||||
}
|
||||
expectedErrStr := "failed parsing max_idle_connections parameter: strconv.Atoi: parsing \"bad param\": invalid syntax"
|
||||
if err.Error() != expectedErrStr {
|
||||
t.Errorf("Expected: %q but found %q", expectedErrStr, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionURL(t *testing.T) {
|
||||
type input struct {
|
||||
envar string
|
||||
conf map[string]string
|
||||
}
|
||||
|
||||
cases := map[string]struct {
|
||||
want string
|
||||
input input
|
||||
}{
|
||||
"environment_variable_not_set_use_config_value": {
|
||||
want: "abc",
|
||||
input: input{
|
||||
envar: "",
|
||||
conf: map[string]string{"connection_url": "abc"},
|
||||
},
|
||||
},
|
||||
|
||||
"no_value_connection_url_set_key_exists": {
|
||||
want: "",
|
||||
input: input{
|
||||
envar: "",
|
||||
conf: map[string]string{"connection_url": ""},
|
||||
},
|
||||
},
|
||||
|
||||
"no_value_connection_url_set_key_doesnt_exist": {
|
||||
want: "",
|
||||
input: input{
|
||||
envar: "",
|
||||
conf: map[string]string{},
|
||||
},
|
||||
},
|
||||
|
||||
"environment_variable_set": {
|
||||
want: "abc",
|
||||
input: input{
|
||||
envar: "abc",
|
||||
conf: map[string]string{"connection_url": "def"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tt := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// This is necessary to avoid always testing the branch where the env is set.
|
||||
// As long the env is set --- even if the value is "" --- `ok` returns true.
|
||||
if tt.input.envar != "" {
|
||||
os.Setenv("VAULT_PG_CONNECTION_URL", tt.input.envar)
|
||||
defer os.Unsetenv("VAULT_PG_CONNECTION_URL")
|
||||
}
|
||||
|
||||
got := connectionURL(tt.input.conf)
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("connectionURL(%s): want %q, got %q", tt.input, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Similar to testHABackend, but using internal implementation details to
|
||||
// trigger the lock failure scenario by setting the lock renew period for one
|
||||
// of the locks to a higher value than the lock TTL.
|
||||
const maxTries = 3
|
||||
|
||||
func testPostgresSQLLockTTL(t *testing.T, ha physical.HABackend) {
|
||||
t.Log("Skipping testPostgresSQLLockTTL portion of test.")
|
||||
return
|
||||
|
||||
for tries := 1; tries <= maxTries; tries++ {
|
||||
// Try this several times. If the test environment is too slow the lock can naturally lapse
|
||||
if attemptLockTTLTest(t, ha, tries) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func attemptLockTTLTest(t *testing.T, ha physical.HABackend, tries int) bool {
|
||||
// Set much smaller lock times to speed up the test.
|
||||
lockTTL := 3
|
||||
renewInterval := time.Second * 1
|
||||
retryInterval := time.Second * 1
|
||||
longRenewInterval := time.Duration(lockTTL*2) * time.Second
|
||||
lockkey := "postgresttl"
|
||||
|
||||
var leaderCh <-chan struct{}
|
||||
|
||||
// Get the lock
|
||||
origLock, err := ha.LockWith(lockkey, "bar")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
{
|
||||
// set the first lock renew period to double the expected TTL.
|
||||
lock := origLock.(*PostgreSQLLock)
|
||||
lock.renewInterval = longRenewInterval
|
||||
lock.ttlSeconds = lockTTL
|
||||
|
||||
// Attempt to lock
|
||||
lockTime := time.Now()
|
||||
leaderCh, err = lock.Lock(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if leaderCh == nil {
|
||||
t.Fatalf("failed to get leader ch")
|
||||
}
|
||||
|
||||
if tries == 1 {
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
// Check the value
|
||||
held, val, err := lock.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !held {
|
||||
if tries < maxTries && time.Since(lockTime) > (time.Second*time.Duration(lockTTL)) {
|
||||
// Our test environment is slow enough that we failed this, retry
|
||||
return false
|
||||
}
|
||||
t.Fatalf("should be held")
|
||||
}
|
||||
if val != "bar" {
|
||||
t.Fatalf("bad value: %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
// Second acquisition should succeed because the first lock should
|
||||
// not renew within the 3 sec TTL.
|
||||
origLock2, err := ha.LockWith(lockkey, "baz")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
{
|
||||
lock2 := origLock2.(*PostgreSQLLock)
|
||||
lock2.renewInterval = renewInterval
|
||||
lock2.ttlSeconds = lockTTL
|
||||
lock2.retryInterval = retryInterval
|
||||
|
||||
// Cancel attempt in 6 sec so as not to block unit tests forever
|
||||
stopCh := make(chan struct{})
|
||||
time.AfterFunc(time.Duration(lockTTL*2)*time.Second, func() {
|
||||
close(stopCh)
|
||||
})
|
||||
|
||||
// Attempt to lock should work
|
||||
lockTime := time.Now()
|
||||
leaderCh2, err := lock2.Lock(stopCh)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if leaderCh2 == nil {
|
||||
t.Fatalf("should get leader ch")
|
||||
}
|
||||
defer lock2.Unlock()
|
||||
|
||||
// Check the value
|
||||
held, val, err := lock2.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !held {
|
||||
if tries < maxTries && time.Since(lockTime) > (time.Second*time.Duration(lockTTL)) {
|
||||
// Our test environment is slow enough that we failed this, retry
|
||||
return false
|
||||
}
|
||||
t.Fatalf("should be held")
|
||||
}
|
||||
if val != "baz" {
|
||||
t.Fatalf("bad value: %v", val)
|
||||
}
|
||||
}
|
||||
// The first lock should have lost the leader channel
|
||||
select {
|
||||
case <-time.After(longRenewInterval * 2):
|
||||
t.Fatalf("original lock did not have its leader channel closed.")
|
||||
case <-leaderCh:
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Verify that once Unlock is called, we don't keep trying to renew the original
|
||||
// lock.
|
||||
func testPostgresSQLLockRenewal(t *testing.T, ha physical.HABackend) {
|
||||
// Get the lock
|
||||
origLock, err := ha.LockWith("pgrenewal", "bar")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// customize the renewal and watch intervals
|
||||
lock := origLock.(*PostgreSQLLock)
|
||||
// lock.renewInterval = time.Second * 1
|
||||
|
||||
// Attempt to lock
|
||||
leaderCh, err := lock.Lock(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if leaderCh == nil {
|
||||
t.Fatalf("failed to get leader ch")
|
||||
}
|
||||
|
||||
// Check the value
|
||||
held, val, err := lock.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !held {
|
||||
t.Fatalf("should be held")
|
||||
}
|
||||
if val != "bar" {
|
||||
t.Fatalf("bad value: %v", val)
|
||||
}
|
||||
|
||||
// Release the lock, which will delete the stored item
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Wait longer than the renewal time
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
// Attempt to lock with new lock
|
||||
newLock, err := ha.LockWith("pgrenewal", "baz")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
stopCh := make(chan struct{})
|
||||
timeout := time.Duration(lock.ttlSeconds)*time.Second + lock.retryInterval + time.Second
|
||||
|
||||
var leaderCh2 <-chan struct{}
|
||||
newlockch := make(chan struct{})
|
||||
go func() {
|
||||
leaderCh2, err = newLock.Lock(stopCh)
|
||||
close(newlockch)
|
||||
}()
|
||||
|
||||
// Cancel attempt after lock ttl + 1s so as not to block unit tests forever
|
||||
select {
|
||||
case <-time.After(timeout):
|
||||
t.Logf("giving up on lock attempt after %v", timeout)
|
||||
close(stopCh)
|
||||
case <-newlockch:
|
||||
// pass through
|
||||
}
|
||||
|
||||
// Attempt to lock should work
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if leaderCh2 == nil {
|
||||
t.Fatalf("should get leader ch")
|
||||
}
|
||||
|
||||
// Check the value
|
||||
held, val, err = newLock.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !held {
|
||||
t.Fatalf("should be held")
|
||||
}
|
||||
if val != "baz" {
|
||||
t.Fatalf("bad value: %v", val)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
newLock.Unlock()
|
||||
}
|
||||
|
||||
func setupDatabaseObjects(t *testing.T, logger log.Logger, pg *PostgreSQLBackend) {
|
||||
var err error
|
||||
// Setup tables and indexes if not exists.
|
||||
createTableSQL := fmt.Sprintf(
|
||||
" CREATE TABLE IF NOT EXISTS %v ( "+
|
||||
" parent_path TEXT COLLATE \"C\" NOT NULL, "+
|
||||
" path TEXT COLLATE \"C\", "+
|
||||
" key TEXT COLLATE \"C\", "+
|
||||
" value BYTEA, "+
|
||||
" CONSTRAINT pkey PRIMARY KEY (path, key) "+
|
||||
" ); ", pg.table)
|
||||
|
||||
_, err = pg.client.Exec(createTableSQL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
createIndexSQL := fmt.Sprintf(" CREATE INDEX IF NOT EXISTS parent_path_idx ON %v (parent_path); ", pg.table)
|
||||
|
||||
_, err = pg.client.Exec(createIndexSQL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create index: %v", err)
|
||||
}
|
||||
|
||||
createHaTableSQL := " CREATE TABLE IF NOT EXISTS vault_ha_locks ( " +
|
||||
" ha_key TEXT COLLATE \"C\" NOT NULL, " +
|
||||
" ha_identity TEXT COLLATE \"C\" NOT NULL, " +
|
||||
" ha_value TEXT COLLATE \"C\", " +
|
||||
" valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " +
|
||||
" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
|
||||
" ); "
|
||||
|
||||
_, err = pg.client.Exec(createHaTableSQL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create hatable: %v", err)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user