1
0

Compare commits

...

2 Commits

Author SHA1 Message Date
f947ce9080 remove aerospike 2024-07-01 12:23:25 +03:00
f33fd378bb remove cassandra 2024-07-01 12:23:25 +03:00
8 changed files with 0 additions and 786 deletions

View File

@ -10,7 +10,6 @@
/builtin/credential/okta/ @hashicorp/vault-ecosystem
# Secrets engines (pki, ssh, totp and transit omitted)
/builtin/logical/cassandra/ @hashicorp/vault-ecosystem
/builtin/logical/consul/ @hashicorp/vault-ecosystem
/builtin/logical/database/ @hashicorp/vault-ecosystem
/builtin/logical/mysql/ @hashicorp/vault-ecosystem

View File

@ -41,8 +41,6 @@ import (
logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
logicalDb "github.com/hashicorp/vault/builtin/logical/database"
physAerospike "github.com/hashicorp/vault/physical/aerospike"
physCassandra "github.com/hashicorp/vault/physical/cassandra"
physCockroachDB "github.com/hashicorp/vault/physical/cockroachdb"
physConsul "github.com/hashicorp/vault/physical/consul"
physFoundationDB "github.com/hashicorp/vault/physical/foundationdb"
@ -170,8 +168,6 @@ var (
}
physicalBackends = map[string]physical.Factory{
"aerospike": physAerospike.NewAerospikeBackend,
"cassandra": physCassandra.NewCassandraBackend,
"cockroachdb": physCockroachDB.NewCockroachDBBackend,
"consul": physConsul.NewConsulBackend,
"file_transactional": physFile.NewTransactionalFileBackend,

2
go.mod
View File

@ -26,7 +26,6 @@ replace github.com/hashicorp/vault/sdk => ./sdk
require (
github.com/ProtonMail/go-crypto v0.0.0-20230626094100-7e9e0395ebec
github.com/aerospike/aerospike-client-go/v5 v5.6.0
github.com/apple/foundationdb/bindings/go v0.0.0-20190411004307-cd5c9d91fad2
github.com/armon/go-metrics v0.4.1
github.com/armon/go-radix v1.0.0
@ -355,7 +354,6 @@ require (
github.com/ulikunitz/xz v0.5.10 // indirect
github.com/vmware/govmomi v0.18.0 // indirect
github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
github.com/yusufpapurcu/wmi v1.2.2 // indirect
github.com/zclconf/go-cty v1.12.1 // indirect
go.mongodb.org/mongo-driver v1.11.6 // indirect

6
go.sum
View File

@ -936,8 +936,6 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko
github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:HI8ITrYtUY+O+ZhtlqUnD8+KwNPOyugEhfP9fdUIaEQ=
github.com/abdullin/seq v0.0.0-20160510034733-d5467c17e7af h1:DBNMBMuMiWYu0b+8KMJuWmfCkcxl09JwdlqwDZZ6U14=
github.com/abdullin/seq v0.0.0-20160510034733-d5467c17e7af/go.mod h1:5Jv4cbFiHJMsVxt52+i0Ha45fjshj6wxYr1r19tB9bw=
github.com/aerospike/aerospike-client-go/v5 v5.6.0 h1:tRxcUq0HY8fFPQEzF3EgrknF+w1xFO0YDfUb9Nm8yRI=
github.com/aerospike/aerospike-client-go/v5 v5.6.0/go.mod h1:rJ/KpmClE7kiBPfvAPrGw9WuNOiz8v2uKbQaUyYPXtI=
github.com/agext/levenshtein v1.2.1 h1:QmvMAjj2aEICytGiWzmxoE0x2KZvE0fvmqMOfy2tjT8=
github.com/agext/levenshtein v1.2.1/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM=
@ -2875,9 +2873,6 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA=
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw=
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA=
github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg=
github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/yvasiyarov/go-metrics v0.0.0-20140926110328-57bccd1ccd43/go.mod h1:aX5oPXxHm3bOH+xeAttToC8pqch2ScQN/JoXYupl6xs=
@ -3316,7 +3311,6 @@ golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190209173611-3b5209105503/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

View File

@ -1,254 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package aerospike
import (
"context"
"crypto/sha256"
"fmt"
"strconv"
"strings"
"time"
aero "github.com/aerospike/aerospike-client-go/v5"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/physical"
)
const (
keyBin = "keyBin"
valueBin = "valueBin"
defaultNamespace = "test"
defaultHostname = "127.0.0.1"
defaultPort = 3000
keyNotFoundError = "Key not found"
)
// AerospikeBackend is a physical backend that stores data in Aerospike.
type AerospikeBackend struct {
client *aero.Client
namespace string
set string
logger log.Logger
}
// Verify AerospikeBackend satisfies the correct interface.
var _ physical.Backend = (*AerospikeBackend)(nil)
// NewAerospikeBackend constructs an AerospikeBackend backend.
func NewAerospikeBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
namespace, ok := conf["namespace"]
if !ok {
namespace = defaultNamespace
}
set := conf["set"]
policy, err := buildClientPolicy(conf)
if err != nil {
return nil, err
}
client, err := buildAerospikeClient(conf, policy)
if err != nil {
return nil, err
}
return &AerospikeBackend{
client: client,
namespace: namespace,
set: set,
logger: logger,
}, nil
}
func buildAerospikeClient(conf map[string]string, policy *aero.ClientPolicy) (*aero.Client, error) {
hostListString, ok := conf["hostlist"]
if !ok || hostListString == "" {
hostname, ok := conf["hostname"]
if !ok || hostname == "" {
hostname = defaultHostname
}
portString, ok := conf["port"]
if !ok || portString == "" {
portString = strconv.Itoa(defaultPort)
}
port, err := strconv.Atoi(portString)
if err != nil {
return nil, err
}
return aero.NewClientWithPolicy(policy, hostname, port)
}
hostList, err := parseHostList(hostListString)
if err != nil {
return nil, err
}
return aero.NewClientWithPolicyAndHost(policy, hostList...)
}
func buildClientPolicy(conf map[string]string) (*aero.ClientPolicy, error) {
policy := aero.NewClientPolicy()
policy.User = conf["username"]
policy.Password = conf["password"]
authMode := aero.AuthModeInternal
if mode, ok := conf["auth_mode"]; ok {
switch strings.ToUpper(mode) {
case "EXTERNAL":
authMode = aero.AuthModeExternal
case "INTERNAL":
authMode = aero.AuthModeInternal
default:
return nil, fmt.Errorf("'auth_mode' must be one of {INTERNAL, EXTERNAL}")
}
}
policy.AuthMode = authMode
policy.ClusterName = conf["cluster_name"]
if timeoutString, ok := conf["timeout"]; ok {
timeout, err := strconv.Atoi(timeoutString)
if err != nil {
return nil, err
}
policy.Timeout = time.Duration(timeout) * time.Millisecond
}
if idleTimeoutString, ok := conf["idle_timeout"]; ok {
idleTimeout, err := strconv.Atoi(idleTimeoutString)
if err != nil {
return nil, err
}
policy.IdleTimeout = time.Duration(idleTimeout) * time.Millisecond
}
return policy, nil
}
func (a *AerospikeBackend) key(userKey string) (*aero.Key, error) {
return aero.NewKey(a.namespace, a.set, hash(userKey))
}
// Put is used to insert or update an entry.
func (a *AerospikeBackend) Put(_ context.Context, entry *physical.Entry) error {
aeroKey, err := a.key(entry.Key)
if err != nil {
return err
}
// replace the Aerospike record if exists
writePolicy := aero.NewWritePolicy(0, 0)
writePolicy.RecordExistsAction = aero.REPLACE
binMap := make(aero.BinMap, 2)
binMap[keyBin] = entry.Key
binMap[valueBin] = entry.Value
return a.client.Put(writePolicy, aeroKey, binMap)
}
// Get is used to fetch an entry.
func (a *AerospikeBackend) Get(_ context.Context, key string) (*physical.Entry, error) {
aeroKey, err := a.key(key)
if err != nil {
return nil, err
}
record, err := a.client.Get(nil, aeroKey)
if err != nil {
if strings.Contains(err.Error(), keyNotFoundError) {
return nil, nil
}
return nil, err
}
value, ok := record.Bins[valueBin]
if !ok {
return nil, fmt.Errorf("Value bin was not found in the record")
}
return &physical.Entry{
Key: key,
Value: value.([]byte),
}, nil
}
// Delete is used to permanently delete an entry.
func (a *AerospikeBackend) Delete(_ context.Context, key string) error {
aeroKey, err := a.key(key)
if err != nil {
return err
}
_, err = a.client.Delete(nil, aeroKey)
return err
}
// List is used to list all the keys under a given
// prefix, up to the next prefix.
func (a *AerospikeBackend) List(_ context.Context, prefix string) ([]string, error) {
recordSet, err := a.client.ScanAll(nil, a.namespace, a.set)
if err != nil {
return nil, err
}
var keyList []string
for res := range recordSet.Results() {
if res.Err != nil {
return nil, res.Err
}
recordKey := res.Record.Bins[keyBin].(string)
if strings.HasPrefix(recordKey, prefix) {
trimPrefix := strings.TrimPrefix(recordKey, prefix)
keys := strings.Split(trimPrefix, "/")
if len(keys) == 1 {
keyList = append(keyList, keys[0])
} else {
withSlash := keys[0] + "/"
if !strutil.StrListContains(keyList, withSlash) {
keyList = append(keyList, withSlash)
}
}
}
}
return keyList, nil
}
func parseHostList(list string) ([]*aero.Host, error) {
hosts := strings.Split(list, ",")
var hostList []*aero.Host
for _, host := range hosts {
if host == "" {
continue
}
split := strings.Split(host, ":")
switch len(split) {
case 1:
hostList = append(hostList, aero.NewHost(split[0], defaultPort))
case 2:
port, err := strconv.Atoi(split[1])
if err != nil {
return nil, err
}
hostList = append(hostList, aero.NewHost(split[0], port))
default:
return nil, fmt.Errorf("Invalid 'hostlist' configuration")
}
}
return hostList, nil
}
func hash(s string) string {
hash := sha256.Sum256([]byte(s))
return fmt.Sprintf("%x", hash[:])
}

View File

@ -1,93 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package aerospike
import (
"context"
"math/bits"
"testing"
"time"
aero "github.com/aerospike/aerospike-client-go/v5"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/docker"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/physical"
)
func TestAerospikeBackend(t *testing.T) {
if bits.UintSize == 32 {
t.Skip("Aerospike storage is only supported on 64-bit architectures")
}
cleanup, config := prepareAerospikeContainer(t)
defer cleanup()
logger := logging.NewVaultLogger(log.Debug)
b, err := NewAerospikeBackend(map[string]string{
"hostname": config.hostname,
"port": config.port,
"namespace": config.namespace,
"set": config.set,
}, logger)
if err != nil {
t.Fatalf("err: %s", err)
}
physical.ExerciseBackend(t, b)
physical.ExerciseBackend_ListPrefix(t, b)
}
type aerospikeConfig struct {
hostname string
port string
namespace string
set string
}
func prepareAerospikeContainer(t *testing.T) (func(), *aerospikeConfig) {
runner, err := docker.NewServiceRunner(docker.RunOptions{
ImageRepo: "docker.mirror.hashicorp.services/aerospike/aerospike-server",
ContainerName: "aerospikedb",
ImageTag: "5.6.0.5",
Ports: []string{"3000/tcp", "3001/tcp", "3002/tcp", "3003/tcp"},
})
if err != nil {
t.Fatalf("Could not start local Aerospike: %s", err)
}
svc, err := runner.StartService(context.Background(),
func(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
cfg := docker.NewServiceHostPort(host, port)
time.Sleep(time.Second)
client, err := aero.NewClient(host, port)
if err != nil {
return nil, err
}
node, err := client.Cluster().GetRandomNode()
if err != nil {
return nil, err
}
_, err = node.RequestInfo(aero.NewInfoPolicy(), "namespaces")
if err != nil {
return nil, err
}
return cfg, nil
},
)
if err != nil {
t.Fatalf("Could not start local Aerospike: %s", err)
}
return svc.Cleanup, &aerospikeConfig{
hostname: svc.Config.URL().Hostname(),
port: svc.Config.URL().Port(),
namespace: "test",
set: "vault",
}
}

View File

@ -1,366 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cassandra
import (
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"strconv"
"strings"
"time"
metrics "github.com/armon/go-metrics"
"github.com/gocql/gocql"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/physical"
)
// CassandraBackend is a physical backend that stores data in Cassandra.
type CassandraBackend struct {
sess *gocql.Session
table string
logger log.Logger
}
// Verify CassandraBackend satisfies the correct interfaces
var _ physical.Backend = (*CassandraBackend)(nil)
// NewCassandraBackend constructs a Cassandra backend using a pre-existing
// keyspace and table.
func NewCassandraBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
splitArray := func(v string) []string {
return strings.FieldsFunc(v, func(r rune) bool {
return r == ','
})
}
var (
hosts = splitArray(conf["hosts"])
port = 9042
explicitPort = false
keyspace = conf["keyspace"]
table = conf["table"]
consistency = gocql.LocalQuorum
)
if len(hosts) == 0 {
hosts = []string{"localhost"}
}
for i, hp := range hosts {
h, ps, err := net.SplitHostPort(hp)
if err != nil {
continue
}
p, err := strconv.Atoi(ps)
if err != nil {
return nil, err
}
if explicitPort && p != port {
return nil, fmt.Errorf("all hosts must have the same port")
}
hosts[i], port = h, p
explicitPort = true
}
if keyspace == "" {
keyspace = "vault"
}
if table == "" {
table = "entries"
}
if cs, ok := conf["consistency"]; ok {
switch cs {
case "ANY":
consistency = gocql.Any
case "ONE":
consistency = gocql.One
case "TWO":
consistency = gocql.Two
case "THREE":
consistency = gocql.Three
case "QUORUM":
consistency = gocql.Quorum
case "ALL":
consistency = gocql.All
case "LOCAL_QUORUM":
consistency = gocql.LocalQuorum
case "EACH_QUORUM":
consistency = gocql.EachQuorum
case "LOCAL_ONE":
consistency = gocql.LocalOne
default:
return nil, fmt.Errorf("'consistency' must be one of {ANY, ONE, TWO, THREE, QUORUM, ALL, LOCAL_QUORUM, EACH_QUORUM, LOCAL_ONE}")
}
}
connectStart := time.Now()
cluster := gocql.NewCluster(hosts...)
cluster.Port = port
cluster.Keyspace = keyspace
if retryCountStr, ok := conf["simple_retry_policy_retries"]; ok {
retryCount, err := strconv.Atoi(retryCountStr)
if err != nil || retryCount <= 0 {
return nil, fmt.Errorf("'simple_retry_policy_retries' must be a positive integer")
}
cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: retryCount}
}
cluster.ProtoVersion = 2
if protoVersionStr, ok := conf["protocol_version"]; ok {
protoVersion, err := strconv.Atoi(protoVersionStr)
if err != nil {
return nil, fmt.Errorf("'protocol_version' must be an integer")
}
cluster.ProtoVersion = protoVersion
}
if username, ok := conf["username"]; ok {
if cluster.ProtoVersion < 2 {
return nil, fmt.Errorf("authentication is not supported with protocol version < 2")
}
authenticator := gocql.PasswordAuthenticator{Username: username}
if password, ok := conf["password"]; ok {
authenticator.Password = password
}
cluster.Authenticator = authenticator
}
if initialConnectionTimeoutStr, ok := conf["initial_connection_timeout"]; ok {
initialConnectionTimeout, err := strconv.Atoi(initialConnectionTimeoutStr)
if err != nil || initialConnectionTimeout <= 0 {
return nil, fmt.Errorf("'initial_connection_timeout' must be a positive integer")
}
cluster.ConnectTimeout = time.Duration(initialConnectionTimeout) * time.Second
}
if connTimeoutStr, ok := conf["connection_timeout"]; ok {
connectionTimeout, err := strconv.Atoi(connTimeoutStr)
if err != nil || connectionTimeout <= 0 {
return nil, fmt.Errorf("'connection_timeout' must be a positive integer")
}
cluster.Timeout = time.Duration(connectionTimeout) * time.Second
}
if err := setupCassandraTLS(conf, cluster); err != nil {
return nil, err
}
sess, err := cluster.CreateSession()
if err != nil {
return nil, err
}
metrics.MeasureSince([]string{"cassandra", "connect"}, connectStart)
sess.SetConsistency(consistency)
impl := &CassandraBackend{
sess: sess,
table: table,
logger: logger,
}
return impl, nil
}
func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) error {
tlsOnStr, ok := conf["tls"]
if !ok {
return nil
}
tlsOn, err := strconv.Atoi(tlsOnStr)
if err != nil {
return fmt.Errorf("'tls' must be an integer (0 or 1)")
}
if tlsOn == 0 {
return nil
}
tlsConfig := &tls.Config{}
if pemBundlePath, ok := conf["pem_bundle_file"]; ok {
pemBundleData, err := ioutil.ReadFile(pemBundlePath)
if err != nil {
return fmt.Errorf("error reading pem bundle from %q: %w", pemBundlePath, err)
}
pemBundle, err := certutil.ParsePEMBundle(string(pemBundleData))
if err != nil {
return fmt.Errorf("error parsing 'pem_bundle': %w", err)
}
tlsConfig, err = pemBundle.GetTLSConfig(certutil.TLSClient)
if err != nil {
return err
}
} else if pemJSONPath, ok := conf["pem_json_file"]; ok {
pemJSONData, err := ioutil.ReadFile(pemJSONPath)
if err != nil {
return fmt.Errorf("error reading json bundle from %q: %w", pemJSONPath, err)
}
pemJSON, err := certutil.ParsePKIJSON([]byte(pemJSONData))
if err != nil {
return err
}
tlsConfig, err = pemJSON.GetTLSConfig(certutil.TLSClient)
if err != nil {
return err
}
}
if tlsSkipVerifyStr, ok := conf["tls_skip_verify"]; ok {
tlsSkipVerify, err := strconv.Atoi(tlsSkipVerifyStr)
if err != nil {
return fmt.Errorf("'tls_skip_verify' must be an integer (0 or 1)")
}
if tlsSkipVerify == 0 {
tlsConfig.InsecureSkipVerify = false
} else {
tlsConfig.InsecureSkipVerify = true
}
}
if tlsMinVersion, ok := conf["tls_min_version"]; ok {
switch tlsMinVersion {
case "tls10":
tlsConfig.MinVersion = tls.VersionTLS10
case "tls11":
tlsConfig.MinVersion = tls.VersionTLS11
case "tls12":
tlsConfig.MinVersion = tls.VersionTLS12
case "tls13":
tlsConfig.MinVersion = tls.VersionTLS13
default:
return fmt.Errorf("'tls_min_version' must be one of `tls10`, `tls11`, `tls12` or `tls13`")
}
}
cluster.SslOpts = &gocql.SslOptions{
Config: tlsConfig,
EnableHostVerification: !tlsConfig.InsecureSkipVerify,
}
return nil
}
// bucketName sanitises a bucket name for Cassandra
func (c *CassandraBackend) bucketName(name string) string {
if name == "" {
name = "."
}
return strings.TrimRight(name, "/")
}
// bucket returns all the prefix buckets the key should be stored at
func (c *CassandraBackend) buckets(key string) []string {
vals := append([]string{""}, physical.Prefixes(key)...)
for i, v := range vals {
vals[i] = c.bucketName(v)
}
return vals
}
// bucket returns the most specific bucket for the key
func (c *CassandraBackend) bucket(key string) string {
bs := c.buckets(key)
return bs[len(bs)-1]
}
// Put is used to insert or update an entry
func (c *CassandraBackend) Put(ctx context.Context, entry *physical.Entry) error {
defer metrics.MeasureSince([]string{"cassandra", "put"}, time.Now())
// Execute inserts to each key prefix simultaneously
stmt := fmt.Sprintf(`INSERT INTO "%s" (bucket, key, value) VALUES (?, ?, ?)`, c.table)
buckets := c.buckets(entry.Key)
results := make(chan error, len(buckets))
for i, _bucket := range buckets {
go func(i int, bucket string) {
var value []byte
if i == len(buckets)-1 {
// Only store the full value if this is the leaf bucket where the entry will actually be read
// otherwise this write is just to allow for list operations
value = entry.Value
}
results <- c.sess.Query(stmt, bucket, entry.Key, value).Exec()
}(i, _bucket)
}
for i := 0; i < len(buckets); i++ {
if err := <-results; err != nil {
return err
}
}
return nil
}
// Get is used to fetch an entry
func (c *CassandraBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
defer metrics.MeasureSince([]string{"cassandra", "get"}, time.Now())
v := []byte(nil)
stmt := fmt.Sprintf(`SELECT value FROM "%s" WHERE bucket = ? AND key = ? LIMIT 1`, c.table)
q := c.sess.Query(stmt, c.bucket(key), key)
if err := q.Scan(&v); err != nil {
if err == gocql.ErrNotFound {
return nil, nil
}
return nil, err
}
return &physical.Entry{
Key: key,
Value: v,
}, nil
}
// Delete is used to permanently delete an entry
func (c *CassandraBackend) Delete(ctx context.Context, key string) error {
defer metrics.MeasureSince([]string{"cassandra", "delete"}, time.Now())
stmt := fmt.Sprintf(`DELETE FROM "%s" WHERE bucket = ? AND key = ?`, c.table)
buckets := c.buckets(key)
results := make(chan error, len(buckets))
for _, bucket := range buckets {
go func(bucket string) {
results <- c.sess.Query(stmt, bucket, key).Exec()
}(bucket)
}
for i := 0; i < len(buckets); i++ {
if err := <-results; err != nil {
return err
}
}
return nil
}
// List is used ot list all the keys under a given
// prefix, up to the next prefix.
func (c *CassandraBackend) List(ctx context.Context, prefix string) ([]string, error) {
defer metrics.MeasureSince([]string{"cassandra", "list"}, time.Now())
stmt := fmt.Sprintf(`SELECT key FROM "%s" WHERE bucket = ?`, c.table)
q := c.sess.Query(stmt, c.bucketName(prefix))
iter := q.Iter()
k, keys := "", []string{}
for iter.Scan(&k) {
// Only return the next "component" (with a trailing slash if it has children)
k = strings.TrimPrefix(k, prefix)
if parts := strings.SplitN(k, "/", 2); len(parts) > 1 {
k = parts[0] + "/"
} else {
k = parts[0]
}
// Deduplicate; this works because the keys are sorted
if len(keys) > 0 && keys[len(keys)-1] == k {
continue
}
keys = append(keys, k)
}
return keys, iter.Close()
}

View File

@ -1,60 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cassandra
import (
"os"
"reflect"
"testing"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/testhelpers/cassandra"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/physical"
)
func TestCassandraBackend(t *testing.T) {
if testing.Short() {
t.Skipf("skipping in short mode")
}
if os.Getenv("VAULT_CI_GO_TEST_RACE") != "" {
t.Skip("skipping race test in CI pending https://github.com/gocql/gocql/pull/1474")
}
host, cleanup := cassandra.PrepareTestContainer(t)
defer cleanup()
// Run vault tests
logger := logging.NewVaultLogger(log.Debug)
b, err := NewCassandraBackend(map[string]string{
"hosts": host.ConnectionURL(),
"protocol_version": "3",
"connection_timeout": "5",
"initial_connection_timeout": "5",
"simple_retry_policy_retries": "3",
}, logger)
if err != nil {
t.Fatalf("Failed to create new backend: %v", err)
}
physical.ExerciseBackend(t, b)
physical.ExerciseBackend_ListPrefix(t, b)
}
func TestCassandraBackendBuckets(t *testing.T) {
expectations := map[string][]string{
"": {"."},
"a": {"."},
"a/b": {".", "a"},
"a/b/c/d/e": {".", "a", "a/b", "a/b/c", "a/b/c/d"},
}
b := &CassandraBackend{}
for input, expected := range expectations {
actual := b.buckets(input)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("bad: %v expected: %v", actual, expected)
}
}
}