1
0

Vault Agent Cache (#6220)

* vault-agent-cache: squashed 250+ commits

* Add proper token revocation validations to the tests

* Add more test cases

* Avoid leaking by not closing request/response bodies; add comments

* Fix revoke orphan use case; update tests

* Add CLI test for making request over unix socket

* agent/cache: remove namespace-related tests

* Strip-off the auto-auth token from the lookup response

* Output listener details along with configuration

* Add scheme to API address output

* leasecache: use IndexNameLease for prefix lease revocations

* Make CLI accept the fully qualified unix address

* export VAULT_AGENT_ADDR=unix://path/to/socket

* unix:/ to unix://
This commit is contained in:
Vishal Nayak 2019-02-14 20:10:36 -05:00 committed by GitHub
parent 5dd50ef281
commit e39a5f28df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 4283 additions and 23 deletions

2
.gitignore vendored
View File

@ -48,7 +48,9 @@ Vagrantfile
# Configs
*.hcl
!command/agent/config/test-fixtures/config.hcl
!command/agent/config/test-fixtures/config-cache.hcl
!command/agent/config/test-fixtures/config-embedded-type.hcl
!command/agent/config/test-fixtures/config-cache-embedded-type.hcl
.DS_Store
.idea

View File

@ -25,6 +25,7 @@ import (
"golang.org/x/time/rate"
)
const EnvVaultAgentAddress = "VAULT_AGENT_ADDR"
const EnvVaultAddress = "VAULT_ADDR"
const EnvVaultCACert = "VAULT_CACERT"
const EnvVaultCAPath = "VAULT_CAPATH"
@ -237,6 +238,10 @@ func (c *Config) ReadEnvironment() error {
if v := os.Getenv(EnvVaultAddress); v != "" {
envAddress = v
}
// Agent's address will take precedence over Vault's address
if v := os.Getenv(EnvVaultAgentAddress); v != "" {
envAddress = v
}
if v := os.Getenv(EnvVaultMaxRetries); v != "" {
maxRetries, err := strconv.ParseUint(v, 10, 32)
if err != nil {
@ -366,6 +371,21 @@ func NewClient(c *Config) (*Client, error) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
// If address begins with a `unix://`, treat it as a socket file path and set
// the HttpClient's transport to the corresponding socket dialer.
if strings.HasPrefix(c.Address, "unix://") {
socketFilePath := strings.TrimPrefix(c.Address, "unix://")
c.HttpClient = &http.Client{
Transport: &http.Transport{
DialContext: func(context.Context, string, string) (net.Conn, error) {
return net.Dial("unix", socketFilePath)
},
},
}
// Set the unix address for URL parsing below
c.Address = "http://unix"
}
u, err := url.Parse(c.Address)
if err != nil {
return nil, err
@ -707,7 +727,7 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon
redirectCount := 0
START:
req, err := r.toRetryableHTTP()
req, err := r.ToRetryableHTTP()
if err != nil {
return nil, err
}

View File

@ -62,7 +62,7 @@ func (r *Request) ResetJSONBody() error {
// DEPRECATED: ToHTTP turns this request into a valid *http.Request for use
// with the net/http package.
func (r *Request) ToHTTP() (*http.Request, error) {
req, err := r.toRetryableHTTP()
req, err := r.ToRetryableHTTP()
if err != nil {
return nil, err
}
@ -85,7 +85,7 @@ func (r *Request) ToHTTP() (*http.Request, error) {
return req.Request, nil
}
func (r *Request) toRetryableHTTP() (*retryablehttp.Request, error) {
func (r *Request) ToRetryableHTTP() (*retryablehttp.Request, error) {
// Encode the query parameters
r.URL.RawQuery = r.Params.Encode()

View File

@ -292,6 +292,7 @@ type SecretAuth struct {
TokenPolicies []string `json:"token_policies"`
IdentityPolicies []string `json:"identity_policies"`
Metadata map[string]string `json:"metadata"`
Orphan bool `json:"orphan"`
LeaseDuration int `json:"lease_duration"`
Renewable bool `json:"renewable"`

View File

@ -4,6 +4,10 @@ import (
"context"
"fmt"
"io"
"net"
"net/http"
"time"
"os"
"sort"
"strings"
@ -23,6 +27,7 @@ import (
"github.com/hashicorp/vault/command/agent/auth/gcp"
"github.com/hashicorp/vault/command/agent/auth/jwt"
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
"github.com/hashicorp/vault/command/agent/cache"
"github.com/hashicorp/vault/command/agent/config"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
@ -218,19 +223,6 @@ func (c *AgentCommand) Run(args []string) int {
info["cgo"] = "enabled"
}
// Server configuration output
padding := 24
sort.Strings(infoKeys)
c.UI.Output("==> Vault agent configuration:\n")
for _, k := range infoKeys {
c.UI.Output(fmt.Sprintf(
"%s%s: %s",
strings.Repeat(" ", padding-len(k)),
strings.Title(k),
info[k]))
}
c.UI.Output("")
// Tests might not want to start a vault server and just want to verify
// the configuration.
if c.flagTestVerifyOnly {
@ -332,10 +324,92 @@ func (c *AgentCommand) Run(args []string) int {
EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials,
})
// Start things running
// Start auto-auth and sink servers
go ah.Run(ctx, method)
go ss.Run(ctx, ah.OutputCh, sinks)
// Parse agent listener configurations
if config.Cache != nil && len(config.Cache.Listeners) != 0 {
cacheLogger := c.logger.Named("cache")
// Create the API proxier
apiProxy := cache.NewAPIProxy(&cache.APIProxyConfig{
Logger: cacheLogger.Named("apiproxy"),
})
// Create the lease cache proxier and set its underlying proxier to
// the API proxier.
leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
BaseContext: ctx,
Proxier: apiProxy,
Logger: cacheLogger.Named("leasecache"),
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error creating lease cache: %v", err))
return 1
}
// Create a muxer and add paths relevant for the lease cache layer
mux := http.NewServeMux()
mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, config.Cache.UseAutoAuthToken, c.client))
var listeners []net.Listener
for i, lnConfig := range config.Cache.Listeners {
listener, props, _, err := cache.ServerListener(lnConfig, c.logWriter, c.UI)
if err != nil {
c.UI.Error(fmt.Sprintf("Error parsing listener configuration: %v", err))
return 1
}
listeners = append(listeners, listener)
scheme := "https://"
if props["tls"] == "disabled" {
scheme = "http://"
}
if lnConfig.Type == "unix" {
scheme = "unix://"
}
infoKey := fmt.Sprintf("api address %d", i+1)
info[infoKey] = scheme + listener.Addr().String()
infoKeys = append(infoKeys, infoKey)
cacheLogger.Info("starting listener", "addr", listener.Addr().String())
server := &http.Server{
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
IdleTimeout: 5 * time.Minute,
ErrorLog: cacheLogger.StandardLogger(nil),
}
go server.Serve(listener)
}
// Ensure that listeners are closed at all the exits
listenerCloseFunc := func() {
for _, ln := range listeners {
ln.Close()
}
}
defer c.cleanupGuard.Do(listenerCloseFunc)
}
// Server configuration output
padding := 24
sort.Strings(infoKeys)
c.UI.Output("==> Vault agent configuration:\n")
for _, k := range infoKeys {
c.UI.Output(fmt.Sprintf(
"%s%s: %s",
strings.Repeat(" ", padding-len(k)),
strings.Title(k),
info[k]))
}
c.UI.Output("")
// Release the log gate.
c.logGate.Flush()

61
command/agent/cache/api_proxy.go vendored Normal file
View File

@ -0,0 +1,61 @@
package cache
import (
"bytes"
"context"
"io/ioutil"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
)
// APIProxy is an implementation of the proxier interface that is used to
// forward the request to Vault and get the response.
type APIProxy struct {
logger hclog.Logger
}
type APIProxyConfig struct {
Logger hclog.Logger
}
func NewAPIProxy(config *APIProxyConfig) Proxier {
return &APIProxy{
logger: config.Logger,
}
}
func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
client, err := api.NewClient(api.DefaultConfig())
if err != nil {
return nil, err
}
client.SetToken(req.Token)
client.SetHeaders(req.Request.Header)
fwReq := client.NewRequest(req.Request.Method, req.Request.URL.Path)
fwReq.BodyBytes = req.RequestBody
// Make the request to Vault and get the response
ap.logger.Info("forwarding request", "path", req.Request.URL.Path, "method", req.Request.Method)
resp, err := client.RawRequestWithContext(ctx, fwReq)
if err != nil {
return nil, err
}
// Parse and reset response body
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
ap.logger.Error("failed to read request body", "error", err)
return nil, err
}
if resp.Body != nil {
resp.Body.Close()
}
resp.Body = ioutil.NopCloser(bytes.NewBuffer(respBody))
return &SendResponse{
Response: resp,
ResponseBody: respBody,
}, nil
}

43
command/agent/cache/api_proxy_test.go vendored Normal file
View File

@ -0,0 +1,43 @@
package cache
import (
"testing"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logging"
"github.com/hashicorp/vault/helper/namespace"
)
func TestCache_APIProxy(t *testing.T) {
cleanup, client, _, _ := setupClusterAndAgent(namespace.RootContext(nil), t, nil)
defer cleanup()
proxier := NewAPIProxy(&APIProxyConfig{
Logger: logging.NewVaultLogger(hclog.Trace),
})
r := client.NewRequest("GET", "/v1/sys/health")
req, err := r.ToRetryableHTTP()
if err != nil {
t.Fatal(err)
}
resp, err := proxier.Send(namespace.RootContext(nil), &SendRequest{
Request: req.Request,
})
if err != nil {
t.Fatal(err)
}
var result api.HealthResponse
err = jsonutil.DecodeJSONFromReader(resp.Response.Body, &result)
if err != nil {
t.Fatal(err)
}
if !result.Initialized || result.Sealed || result.Standby {
t.Fatalf("bad sys/health response")
}
}

926
command/agent/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,926 @@
package cache
import (
"context"
"fmt"
"net"
"net/http"
"os"
"testing"
"time"
"github.com/hashicorp/vault/logical"
"github.com/go-test/deep"
hclog "github.com/hashicorp/go-hclog"
kv "github.com/hashicorp/vault-plugin-secrets-kv"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/helper/logging"
"github.com/hashicorp/vault/helper/namespace"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
)
const policyAdmin = `
path "*" {
capabilities = ["sudo", "create", "read", "update", "delete", "list"]
}
`
// setupClusterAndAgent is a helper func used to set up a test cluster and
// caching agent. It returns a cleanup func that should be deferred immediately
// along with two clients, one for direct cluster communication and another to
// talk to the caching agent.
func setupClusterAndAgent(ctx context.Context, t *testing.T, coreConfig *vault.CoreConfig) (func(), *api.Client, *api.Client, *LeaseCache) {
t.Helper()
if ctx == nil {
ctx = context.Background()
}
// Handle sane defaults
if coreConfig == nil {
coreConfig = &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: logging.NewVaultLogger(hclog.Trace),
CredentialBackends: map[string]logical.Factory{
"userpass": userpass.Factory,
},
}
}
if coreConfig.CredentialBackends == nil {
coreConfig.CredentialBackends = map[string]logical.Factory{
"userpass": userpass.Factory,
}
}
// Init new test cluster
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
cores := cluster.Cores
vault.TestWaitActive(t, cores[0].Core)
// clusterClient is the client that is used to talk directly to the cluster.
clusterClient := cores[0].Client
// Add an admin policy
if err := clusterClient.Sys().PutPolicy("admin", policyAdmin); err != nil {
t.Fatal(err)
}
// Set up the userpass auth backend and an admin user. Used for getting a token
// for the agent later down in this func.
clusterClient.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
Type: "userpass",
})
_, err := clusterClient.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
"password": "bar",
"policies": []string{"admin"},
})
if err != nil {
t.Fatal(err)
}
// Set up env vars for agent consumption
origEnvVaultAddress := os.Getenv(api.EnvVaultAddress)
os.Setenv(api.EnvVaultAddress, clusterClient.Address())
origEnvVaultCACert := os.Getenv(api.EnvVaultCACert)
os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))
cacheLogger := logging.NewVaultLogger(hclog.Trace).Named("cache")
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
// Create the API proxier
apiProxy := NewAPIProxy(&APIProxyConfig{
Logger: cacheLogger.Named("apiproxy"),
})
// Create the lease cache proxier and set its underlying proxier to
// the API proxier.
leaseCache, err := NewLeaseCache(&LeaseCacheConfig{
BaseContext: ctx,
Proxier: apiProxy,
Logger: cacheLogger.Named("leasecache"),
})
if err != nil {
t.Fatal(err)
}
// Create a muxer and add paths relevant for the lease cache layer
mux := http.NewServeMux()
mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
mux.Handle("/", Handler(ctx, cacheLogger, leaseCache, false, clusterClient))
server := &http.Server{
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
IdleTimeout: 5 * time.Minute,
ErrorLog: cacheLogger.StandardLogger(nil),
}
go server.Serve(listener)
// testClient is the client that is used to talk to the agent for proxying/caching behavior.
testClient, err := clusterClient.Clone()
if err != nil {
t.Fatal(err)
}
if err := testClient.SetAddress("http://" + listener.Addr().String()); err != nil {
t.Fatal(err)
}
// Login via userpass method to derive a managed token. Set that token as the
// testClient's token
resp, err := testClient.Logical().Write("auth/userpass/login/foo", map[string]interface{}{
"password": "bar",
})
if err != nil {
t.Fatal(err)
}
testClient.SetToken(resp.Auth.ClientToken)
cleanup := func() {
cluster.Cleanup()
os.Setenv(api.EnvVaultAddress, origEnvVaultAddress)
os.Setenv(api.EnvVaultCACert, origEnvVaultCACert)
listener.Close()
}
return cleanup, clusterClient, testClient, leaseCache
}
func tokenRevocationValidation(t *testing.T, sampleSpace map[string]string, expected map[string]string, leaseCache *LeaseCache) {
t.Helper()
for val, valType := range sampleSpace {
index, err := leaseCache.db.Get(valType, val)
if err != nil {
t.Fatal(err)
}
if expected[val] == "" && index != nil {
t.Fatalf("failed to evict index from the cache: type: %q, value: %q", valType, val)
}
if expected[val] != "" && index == nil {
t.Fatalf("evicted an undesired index from cache: type: %q, value: %q", valType, val)
}
}
}
func TestCache_TokenRevocations_RevokeOrphan(t *testing.T) {
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: hclog.NewNullLogger(),
LogicalBackends: map[string]logical.Factory{
"kv": vault.LeasedPassthroughBackendFactory,
},
}
sampleSpace := make(map[string]string)
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
defer cleanup()
token1 := testClient.Token()
sampleSpace[token1] = "token"
// Mount the kv backend
err := testClient.Sys().Mount("kv", &api.MountInput{
Type: "kv",
})
if err != nil {
t.Fatal(err)
}
// Create a secret in the backend
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
"value": "bar",
"ttl": "1h",
})
if err != nil {
t.Fatal(err)
}
// Read the secret and create a lease
leaseResp, err := testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease1 := leaseResp.LeaseID
sampleSpace[lease1] = "lease"
resp, err := testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token2 := resp.Auth.ClientToken
sampleSpace[token2] = "token"
testClient.SetToken(token2)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease2 := leaseResp.LeaseID
sampleSpace[lease2] = "lease"
resp, err = testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token3 := resp.Auth.ClientToken
sampleSpace[token3] = "token"
testClient.SetToken(token3)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease3 := leaseResp.LeaseID
sampleSpace[lease3] = "lease"
expected := make(map[string]string)
for k, v := range sampleSpace {
expected[k] = v
}
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
// Revoke-orphan the intermediate token. This should result in its own
// eviction and evictions of the revoked token's leases. All other things
// including the child tokens and leases of the child tokens should be
// untouched.
testClient.SetToken(token2)
err = testClient.Auth().Token().RevokeOrphan(token2)
if err != nil {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
expected = map[string]string{
token1: "token",
lease1: "lease",
token3: "token",
lease3: "lease",
}
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
func TestCache_TokenRevocations_LeafLevelToken(t *testing.T) {
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: hclog.NewNullLogger(),
LogicalBackends: map[string]logical.Factory{
"kv": vault.LeasedPassthroughBackendFactory,
},
}
sampleSpace := make(map[string]string)
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
defer cleanup()
token1 := testClient.Token()
sampleSpace[token1] = "token"
// Mount the kv backend
err := testClient.Sys().Mount("kv", &api.MountInput{
Type: "kv",
})
if err != nil {
t.Fatal(err)
}
// Create a secret in the backend
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
"value": "bar",
"ttl": "1h",
})
if err != nil {
t.Fatal(err)
}
// Read the secret and create a lease
leaseResp, err := testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease1 := leaseResp.LeaseID
sampleSpace[lease1] = "lease"
resp, err := testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token2 := resp.Auth.ClientToken
sampleSpace[token2] = "token"
testClient.SetToken(token2)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease2 := leaseResp.LeaseID
sampleSpace[lease2] = "lease"
resp, err = testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token3 := resp.Auth.ClientToken
sampleSpace[token3] = "token"
testClient.SetToken(token3)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease3 := leaseResp.LeaseID
sampleSpace[lease3] = "lease"
expected := make(map[string]string)
for k, v := range sampleSpace {
expected[k] = v
}
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
// Revoke the lef token. This should evict all the leases belonging to this
// token, evict entries for all the child tokens and their respective
// leases.
testClient.SetToken(token3)
err = testClient.Auth().Token().RevokeSelf("")
if err != nil {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
expected = map[string]string{
token1: "token",
lease1: "lease",
token2: "token",
lease2: "lease",
}
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
func TestCache_TokenRevocations_IntermediateLevelToken(t *testing.T) {
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: hclog.NewNullLogger(),
LogicalBackends: map[string]logical.Factory{
"kv": vault.LeasedPassthroughBackendFactory,
},
}
sampleSpace := make(map[string]string)
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
defer cleanup()
token1 := testClient.Token()
sampleSpace[token1] = "token"
// Mount the kv backend
err := testClient.Sys().Mount("kv", &api.MountInput{
Type: "kv",
})
if err != nil {
t.Fatal(err)
}
// Create a secret in the backend
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
"value": "bar",
"ttl": "1h",
})
if err != nil {
t.Fatal(err)
}
// Read the secret and create a lease
leaseResp, err := testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease1 := leaseResp.LeaseID
sampleSpace[lease1] = "lease"
resp, err := testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token2 := resp.Auth.ClientToken
sampleSpace[token2] = "token"
testClient.SetToken(token2)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease2 := leaseResp.LeaseID
sampleSpace[lease2] = "lease"
resp, err = testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token3 := resp.Auth.ClientToken
sampleSpace[token3] = "token"
testClient.SetToken(token3)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease3 := leaseResp.LeaseID
sampleSpace[lease3] = "lease"
expected := make(map[string]string)
for k, v := range sampleSpace {
expected[k] = v
}
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
// Revoke the second level token. This should evict all the leases
// belonging to this token, evict entries for all the child tokens and
// their respective leases.
testClient.SetToken(token2)
err = testClient.Auth().Token().RevokeSelf("")
if err != nil {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
expected = map[string]string{
token1: "token",
lease1: "lease",
}
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
func TestCache_TokenRevocations_TopLevelToken(t *testing.T) {
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: hclog.NewNullLogger(),
LogicalBackends: map[string]logical.Factory{
"kv": vault.LeasedPassthroughBackendFactory,
},
}
sampleSpace := make(map[string]string)
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
defer cleanup()
token1 := testClient.Token()
sampleSpace[token1] = "token"
// Mount the kv backend
err := testClient.Sys().Mount("kv", &api.MountInput{
Type: "kv",
})
if err != nil {
t.Fatal(err)
}
// Create a secret in the backend
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
"value": "bar",
"ttl": "1h",
})
if err != nil {
t.Fatal(err)
}
// Read the secret and create a lease
leaseResp, err := testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease1 := leaseResp.LeaseID
sampleSpace[lease1] = "lease"
resp, err := testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token2 := resp.Auth.ClientToken
sampleSpace[token2] = "token"
testClient.SetToken(token2)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease2 := leaseResp.LeaseID
sampleSpace[lease2] = "lease"
resp, err = testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token3 := resp.Auth.ClientToken
sampleSpace[token3] = "token"
testClient.SetToken(token3)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease3 := leaseResp.LeaseID
sampleSpace[lease3] = "lease"
expected := make(map[string]string)
for k, v := range sampleSpace {
expected[k] = v
}
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
// Revoke the top level token. This should evict all the leases belonging
// to this token, evict entries for all the child tokens and their
// respective leases.
testClient.SetToken(token1)
err = testClient.Auth().Token().RevokeSelf("")
if err != nil {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
expected = make(map[string]string)
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
func TestCache_TokenRevocations_Shutdown(t *testing.T) {
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: hclog.NewNullLogger(),
LogicalBackends: map[string]logical.Factory{
"kv": vault.LeasedPassthroughBackendFactory,
},
}
sampleSpace := make(map[string]string)
ctx, rootCancelFunc := context.WithCancel(namespace.RootContext(nil))
cleanup, _, testClient, leaseCache := setupClusterAndAgent(ctx, t, coreConfig)
defer cleanup()
token1 := testClient.Token()
sampleSpace[token1] = "token"
// Mount the kv backend
err := testClient.Sys().Mount("kv", &api.MountInput{
Type: "kv",
})
if err != nil {
t.Fatal(err)
}
// Create a secret in the backend
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
"value": "bar",
"ttl": "1h",
})
if err != nil {
t.Fatal(err)
}
// Read the secret and create a lease
leaseResp, err := testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease1 := leaseResp.LeaseID
sampleSpace[lease1] = "lease"
resp, err := testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token2 := resp.Auth.ClientToken
sampleSpace[token2] = "token"
testClient.SetToken(token2)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease2 := leaseResp.LeaseID
sampleSpace[lease2] = "lease"
resp, err = testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token3 := resp.Auth.ClientToken
sampleSpace[token3] = "token"
testClient.SetToken(token3)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease3 := leaseResp.LeaseID
sampleSpace[lease3] = "lease"
expected := make(map[string]string)
for k, v := range sampleSpace {
expected[k] = v
}
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
rootCancelFunc()
time.Sleep(1 * time.Second)
// Ensure that all the entries are now gone
expected = make(map[string]string)
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
func TestCache_TokenRevocations_BaseContextCancellation(t *testing.T) {
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: hclog.NewNullLogger(),
LogicalBackends: map[string]logical.Factory{
"kv": vault.LeasedPassthroughBackendFactory,
},
}
sampleSpace := make(map[string]string)
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
defer cleanup()
token1 := testClient.Token()
sampleSpace[token1] = "token"
// Mount the kv backend
err := testClient.Sys().Mount("kv", &api.MountInput{
Type: "kv",
})
if err != nil {
t.Fatal(err)
}
// Create a secret in the backend
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
"value": "bar",
"ttl": "1h",
})
if err != nil {
t.Fatal(err)
}
// Read the secret and create a lease
leaseResp, err := testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease1 := leaseResp.LeaseID
sampleSpace[lease1] = "lease"
resp, err := testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token2 := resp.Auth.ClientToken
sampleSpace[token2] = "token"
testClient.SetToken(token2)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease2 := leaseResp.LeaseID
sampleSpace[lease2] = "lease"
resp, err = testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token3 := resp.Auth.ClientToken
sampleSpace[token3] = "token"
testClient.SetToken(token3)
leaseResp, err = testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
lease3 := leaseResp.LeaseID
sampleSpace[lease3] = "lease"
expected := make(map[string]string)
for k, v := range sampleSpace {
expected[k] = v
}
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
// Cancel the base context of the lease cache. This should trigger
// evictions of all the entries from the cache.
leaseCache.baseCtxInfo.CancelFunc()
time.Sleep(1 * time.Second)
// Ensure that all the entries are now gone
expected = make(map[string]string)
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
func TestCache_NonCacheable(t *testing.T) {
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: hclog.NewNullLogger(),
LogicalBackends: map[string]logical.Factory{
"kv": kv.Factory,
},
}
cleanup, _, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
defer cleanup()
// Query mounts first
origMounts, err := testClient.Sys().ListMounts()
if err != nil {
t.Fatal(err)
}
// Mount a kv backend
if err := testClient.Sys().Mount("kv", &api.MountInput{
Type: "kv",
Options: map[string]string{
"version": "2",
},
}); err != nil {
t.Fatal(err)
}
// Query mounts again
newMounts, err := testClient.Sys().ListMounts()
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(origMounts, newMounts); diff == nil {
t.Logf("response #1: %#v", origMounts)
t.Logf("response #2: %#v", newMounts)
t.Fatal("expected requests to be not cached")
}
}
func TestCache_AuthResponse(t *testing.T) {
cleanup, _, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, nil)
defer cleanup()
resp, err := testClient.Logical().Write("auth/token/create", nil)
if err != nil {
t.Fatal(err)
}
token := resp.Auth.ClientToken
testClient.SetToken(token)
authTokeCreateReq := func(t *testing.T, policies map[string]interface{}) *api.Secret {
resp, err := testClient.Logical().Write("auth/token/create", policies)
if err != nil {
t.Fatal(err)
}
if resp.Auth == nil || resp.Auth.ClientToken == "" {
t.Fatalf("expected a valid client token in the response, got = %#v", resp)
}
return resp
}
// Test on auth response by creating a child token
{
proxiedResp := authTokeCreateReq(t, map[string]interface{}{
"policies": "default",
})
cachedResp := authTokeCreateReq(t, map[string]interface{}{
"policies": "default",
})
if diff := deep.Equal(proxiedResp.Auth.ClientToken, cachedResp.Auth.ClientToken); diff != nil {
t.Fatal(diff)
}
}
// Test on *non-renewable* auth response by creating a child root token
{
proxiedResp := authTokeCreateReq(t, nil)
cachedResp := authTokeCreateReq(t, nil)
if diff := deep.Equal(proxiedResp.Auth.ClientToken, cachedResp.Auth.ClientToken); diff != nil {
t.Fatal(diff)
}
}
}
func TestCache_LeaseResponse(t *testing.T) {
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: hclog.NewNullLogger(),
LogicalBackends: map[string]logical.Factory{
"kv": vault.LeasedPassthroughBackendFactory,
},
}
cleanup, client, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
defer cleanup()
err := client.Sys().Mount("kv", &api.MountInput{
Type: "kv",
})
if err != nil {
t.Fatal(err)
}
// Test proxy by issuing two different requests
{
// Write data to the lease-kv backend
_, err := testClient.Logical().Write("kv/foo", map[string]interface{}{
"value": "bar",
"ttl": "1h",
})
if err != nil {
t.Fatal(err)
}
_, err = testClient.Logical().Write("kv/foobar", map[string]interface{}{
"value": "bar",
"ttl": "1h",
})
if err != nil {
t.Fatal(err)
}
firstResp, err := testClient.Logical().Read("kv/foo")
if err != nil {
t.Fatal(err)
}
secondResp, err := testClient.Logical().Read("kv/foobar")
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(firstResp, secondResp); diff == nil {
t.Logf("response: %#v", firstResp)
t.Fatal("expected proxied responses, got cached response on second request")
}
}
// Test caching behavior by issue the same request twice
{
_, err := testClient.Logical().Write("kv/baz", map[string]interface{}{
"value": "foo",
"ttl": "1h",
})
if err != nil {
t.Fatal(err)
}
proxiedResp, err := testClient.Logical().Read("kv/baz")
if err != nil {
t.Fatal(err)
}
cachedResp, err := testClient.Logical().Read("kv/baz")
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(proxiedResp, cachedResp); diff != nil {
t.Fatal(diff)
}
}
}

View File

@ -0,0 +1,265 @@
package cachememdb
import (
"errors"
"fmt"
memdb "github.com/hashicorp/go-memdb"
)
const (
tableNameIndexer = "indexer"
)
// CacheMemDB is the underlying cache database for storing indexes.
type CacheMemDB struct {
db *memdb.MemDB
}
// New creates a new instance of CacheMemDB.
func New() (*CacheMemDB, error) {
db, err := newDB()
if err != nil {
return nil, err
}
return &CacheMemDB{
db: db,
}, nil
}
func newDB() (*memdb.MemDB, error) {
cacheSchema := &memdb.DBSchema{
Tables: map[string]*memdb.TableSchema{
tableNameIndexer: &memdb.TableSchema{
Name: tableNameIndexer,
Indexes: map[string]*memdb.IndexSchema{
// This index enables fetching the cached item based on the
// identifier of the index.
IndexNameID: &memdb.IndexSchema{
Name: IndexNameID,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "ID",
},
},
// This index enables fetching all the entries in cache for
// a given request path, in a given namespace.
IndexNameRequestPath: &memdb.IndexSchema{
Name: IndexNameRequestPath,
Unique: false,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "Namespace",
},
&memdb.StringFieldIndex{
Field: "RequestPath",
},
},
},
},
// This index enables fetching all the entries in cache
// belonging to the leases of a given token.
IndexNameLeaseToken: &memdb.IndexSchema{
Name: IndexNameLeaseToken,
Unique: false,
AllowMissing: true,
Indexer: &memdb.StringFieldIndex{
Field: "LeaseToken",
},
},
// This index enables fetching all the entries in cache
// that are tied to the given token, regardless of the
// entries belonging to the token or belonging to the
// lease.
IndexNameToken: &memdb.IndexSchema{
Name: IndexNameToken,
Unique: true,
AllowMissing: true,
Indexer: &memdb.StringFieldIndex{
Field: "Token",
},
},
// This index enables fetching all the entries in cache for
// the given parent token.
IndexNameTokenParent: &memdb.IndexSchema{
Name: IndexNameTokenParent,
Unique: false,
AllowMissing: true,
Indexer: &memdb.StringFieldIndex{
Field: "TokenParent",
},
},
// This index enables fetching all the entries in cache for
// the given accessor.
IndexNameTokenAccessor: &memdb.IndexSchema{
Name: IndexNameTokenAccessor,
Unique: true,
AllowMissing: true,
Indexer: &memdb.StringFieldIndex{
Field: "TokenAccessor",
},
},
// This index enables fetching all the entries in cache for
// the given lease identifier.
IndexNameLease: &memdb.IndexSchema{
Name: IndexNameLease,
Unique: true,
AllowMissing: true,
Indexer: &memdb.StringFieldIndex{
Field: "Lease",
},
},
},
},
},
}
db, err := memdb.NewMemDB(cacheSchema)
if err != nil {
return nil, err
}
return db, nil
}
// Get returns the index based on the indexer and the index values provided.
func (c *CacheMemDB) Get(indexName string, indexValues ...interface{}) (*Index, error) {
if !validIndexName(indexName) {
return nil, fmt.Errorf("invalid index name %q", indexName)
}
raw, err := c.db.Txn(false).First(tableNameIndexer, indexName, indexValues...)
if err != nil {
return nil, err
}
if raw == nil {
return nil, nil
}
index, ok := raw.(*Index)
if !ok {
return nil, errors.New("unable to parse index value from the cache")
}
return index, nil
}
// Set stores the index into the cache.
func (c *CacheMemDB) Set(index *Index) error {
if index == nil {
return errors.New("nil index provided")
}
txn := c.db.Txn(true)
defer txn.Abort()
if err := txn.Insert(tableNameIndexer, index); err != nil {
return fmt.Errorf("unable to insert index into cache: %v", err)
}
txn.Commit()
return nil
}
// GetByPrefix returns all the cached indexes based on the index name and the
// value prefix.
func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ([]*Index, error) {
if !validIndexName(indexName) {
return nil, fmt.Errorf("invalid index name %q", indexName)
}
indexName = indexName + "_prefix"
// Get all the objects
iter, err := c.db.Txn(false).Get(tableNameIndexer, indexName, indexValues...)
if err != nil {
return nil, err
}
var indexes []*Index
for {
obj := iter.Next()
if obj == nil {
break
}
index, ok := obj.(*Index)
if !ok {
return nil, fmt.Errorf("failed to cast cached index")
}
indexes = append(indexes, index)
}
return indexes, nil
}
// Evict removes an index from the cache based on index name and value.
func (c *CacheMemDB) Evict(indexName string, indexValues ...interface{}) error {
index, err := c.Get(indexName, indexValues...)
if err != nil {
return fmt.Errorf("unable to fetch index on cache deletion: %v", err)
}
if index == nil {
return nil
}
txn := c.db.Txn(true)
defer txn.Abort()
if err := txn.Delete(tableNameIndexer, index); err != nil {
return fmt.Errorf("unable to delete index from cache: %v", err)
}
txn.Commit()
return nil
}
// EvictAll removes all matching indexes from the cache based on index name and value.
func (c *CacheMemDB) EvictAll(indexName, indexValue string) error {
return c.batchEvict(false, indexName, indexValue)
}
// EvictByPrefix removes all matching prefix indexes from the cache based on index name and prefix.
func (c *CacheMemDB) EvictByPrefix(indexName, indexPrefix string) error {
return c.batchEvict(true, indexName, indexPrefix)
}
// batchEvict is a helper that supports eviction based on absolute and prefixed index values.
func (c *CacheMemDB) batchEvict(isPrefix bool, indexName string, indexValues ...interface{}) error {
if !validIndexName(indexName) {
return fmt.Errorf("invalid index name %q", indexName)
}
if isPrefix {
indexName = indexName + "_prefix"
}
txn := c.db.Txn(true)
defer txn.Abort()
_, err := txn.DeleteAll(tableNameIndexer, indexName, indexValues...)
if err != nil {
return err
}
txn.Commit()
return nil
}
// Flush resets the underlying cache object.
func (c *CacheMemDB) Flush() error {
newDB, err := newDB()
if err != nil {
return err
}
c.db = newDB
return nil
}

View File

@ -0,0 +1,388 @@
package cachememdb
import (
"context"
"testing"
"github.com/go-test/deep"
)
func testContextInfo() *ContextInfo {
ctx, cancelFunc := context.WithCancel(context.Background())
return &ContextInfo{
Ctx: ctx,
CancelFunc: cancelFunc,
}
}
func TestNew(t *testing.T) {
_, err := New()
if err != nil {
t.Fatal(err)
}
}
func TestCacheMemDB_Get(t *testing.T) {
cache, err := New()
if err != nil {
t.Fatal(err)
}
// Test invalid index name
_, err = cache.Get("foo", "bar")
if err == nil {
t.Fatal("expected error")
}
// Test on empty cache
index, err := cache.Get(IndexNameID, "foo")
if err != nil {
t.Fatal(err)
}
if index != nil {
t.Fatalf("expected nil index, got: %v", index)
}
// Populate cache
in := &Index{
ID: "test_id",
Namespace: "test_ns/",
RequestPath: "/v1/request/path",
Token: "test_token",
TokenAccessor: "test_accessor",
Lease: "test_lease",
Response: []byte("hello world"),
}
if err := cache.Set(in); err != nil {
t.Fatal(err)
}
testCases := []struct {
name string
indexName string
indexValues []interface{}
}{
{
"by_index_id",
"id",
[]interface{}{in.ID},
},
{
"by_request_path",
"request_path",
[]interface{}{in.Namespace, in.RequestPath},
},
{
"by_lease",
"lease",
[]interface{}{in.Lease},
},
{
"by_token",
"token",
[]interface{}{in.Token},
},
{
"by_token_accessor",
"token_accessor",
[]interface{}{in.TokenAccessor},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
out, err := cache.Get(tc.indexName, tc.indexValues...)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(in, out); diff != nil {
t.Fatal(diff)
}
})
}
}
func TestCacheMemDB_GetByPrefix(t *testing.T) {
cache, err := New()
if err != nil {
t.Fatal(err)
}
// Test invalid index name
_, err = cache.GetByPrefix("foo", "bar", "baz")
if err == nil {
t.Fatal("expected error")
}
// Test on empty cache
index, err := cache.GetByPrefix(IndexNameRequestPath, "foo", "bar")
if err != nil {
t.Fatal(err)
}
if index != nil {
t.Fatalf("expected nil index, got: %v", index)
}
// Populate cache
in := &Index{
ID: "test_id",
Namespace: "test_ns/",
RequestPath: "/v1/request/path/1",
Token: "test_token",
TokenAccessor: "test_accessor",
Lease: "path/to/test_lease/1",
Response: []byte("hello world"),
}
if err := cache.Set(in); err != nil {
t.Fatal(err)
}
// Populate cache
in2 := &Index{
ID: "test_id_2",
Namespace: "test_ns/",
RequestPath: "/v1/request/path/2",
Token: "test_token",
TokenAccessor: "test_accessor",
Lease: "path/to/test_lease/2",
Response: []byte("hello world"),
}
if err := cache.Set(in2); err != nil {
t.Fatal(err)
}
testCases := []struct {
name string
indexName string
indexValues []interface{}
}{
{
"by_request_path",
"request_path",
[]interface{}{"test_ns/", "/v1/request/path"},
},
{
"by_lease",
"lease",
[]interface{}{"path/to/test_lease"},
},
{
"by_token",
"token",
[]interface{}{"test_token"},
},
{
"by_token_accessor",
"token_accessor",
[]interface{}{"test_accessor"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
out, err := cache.GetByPrefix(tc.indexName, tc.indexValues...)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal([]*Index{in, in2}, out); diff != nil {
t.Fatal(diff)
}
})
}
}
func TestCacheMemDB_Set(t *testing.T) {
cache, err := New()
if err != nil {
t.Fatal(err)
}
testCases := []struct {
name string
index *Index
wantErr bool
}{
{
"nil",
nil,
true,
},
{
"empty_fields",
&Index{},
true,
},
{
"missing_required_fields",
&Index{
Lease: "foo",
},
true,
},
{
"all_fields",
&Index{
ID: "test_id",
Namespace: "test_ns/",
RequestPath: "/v1/request/path",
Token: "test_token",
TokenAccessor: "test_accessor",
Lease: "test_lease",
RenewCtxInfo: testContextInfo(),
},
false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if err := cache.Set(tc.index); (err != nil) != tc.wantErr {
t.Fatalf("CacheMemDB.Set() error = %v, wantErr = %v", err, tc.wantErr)
}
})
}
}
func TestCacheMemDB_Evict(t *testing.T) {
cache, err := New()
if err != nil {
t.Fatal(err)
}
// Test on empty cache
if err := cache.Evict(IndexNameID, "foo"); err != nil {
t.Fatal(err)
}
testIndex := &Index{
ID: "test_id",
Namespace: "test_ns/",
RequestPath: "/v1/request/path",
Token: "test_token",
TokenAccessor: "test_token_accessor",
Lease: "test_lease",
RenewCtxInfo: testContextInfo(),
}
testCases := []struct {
name string
indexName string
indexValues []interface{}
insertIndex *Index
wantErr bool
}{
{
"empty_params",
"",
[]interface{}{""},
nil,
true,
},
{
"invalid_params",
"foo",
[]interface{}{"bar"},
nil,
true,
},
{
"by_id",
"id",
[]interface{}{"test_id"},
testIndex,
false,
},
{
"by_request_path",
"request_path",
[]interface{}{"test_ns/", "/v1/request/path"},
testIndex,
false,
},
{
"by_token",
"token",
[]interface{}{"test_token"},
testIndex,
false,
},
{
"by_token_accessor",
"token_accessor",
[]interface{}{"test_accessor"},
testIndex,
false,
},
{
"by_lease",
"lease",
[]interface{}{"test_lease"},
testIndex,
false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.insertIndex != nil {
if err := cache.Set(tc.insertIndex); err != nil {
t.Fatal(err)
}
}
if err := cache.Evict(tc.indexName, tc.indexValues...); (err != nil) != tc.wantErr {
t.Fatal(err)
}
// Verify that the cache doesn't contain the entry any more
index, err := cache.Get(tc.indexName, tc.indexValues...)
if (err != nil) != tc.wantErr {
t.Fatal(err)
}
if index != nil {
t.Fatalf("expected nil entry, got = %#v", index)
}
})
}
}
func TestCacheMemDB_Flush(t *testing.T) {
cache, err := New()
if err != nil {
t.Fatal(err)
}
// Populate cache
in := &Index{
ID: "test_id",
Token: "test_token",
Lease: "test_lease",
Namespace: "test_ns/",
RequestPath: "/v1/request/path",
Response: []byte("hello world"),
}
if err := cache.Set(in); err != nil {
t.Fatal(err)
}
// Reset the cache
if err := cache.Flush(); err != nil {
t.Fatal(err)
}
// Check the cache doesn't contain inserted index
out, err := cache.Get(IndexNameID, "test_id")
if err != nil {
t.Fatal(err)
}
if out != nil {
t.Fatalf("expected cache to be empty, got = %v", out)
}
}

97
command/agent/cache/cachememdb/index.go vendored Normal file
View File

@ -0,0 +1,97 @@
package cachememdb
import "context"
type ContextInfo struct {
Ctx context.Context
CancelFunc context.CancelFunc
DoneCh chan struct{}
}
// Index holds the response to be cached along with multiple other values that
// serve as pointers to refer back to this index.
type Index struct {
// ID is a value that uniquely represents the request held by this
// index. This is computed by serializing and hashing the response object.
// Required: true, Unique: true
ID string
// Token is the token that fetched the response held by this index
// Required: true, Unique: true
Token string
// TokenParent is the parent token of the token held by this index
// Required: false, Unique: false
TokenParent string
// TokenAccessor is the accessor of the token being cached in this index
// Required: true, Unique: true
TokenAccessor string
// Namespace is the namespace that was provided in the request path as the
// Vault namespace to query
Namespace string
// RequestPath is the path of the request that resulted in the response
// held by this index.
// Required: true, Unique: false
RequestPath string
// Lease is the identifier of the lease in Vault, that belongs to the
// response held by this index.
// Required: false, Unique: true
Lease string
// LeaseToken is the identifier of the token that created the lease held by
// this index.
// Required: false, Unique: false
LeaseToken string
// Response is the serialized response object that the agent is caching.
Response []byte
// RenewCtxInfo holds the context and the corresponding cancel func for the
// goroutine that manages the renewal of the secret belonging to the
// response in this index.
RenewCtxInfo *ContextInfo
}
type IndexName uint32
const (
// IndexNameID is the ID of the index constructed from the serialized request.
IndexNameID = "id"
// IndexNameLease is the lease of the index.
IndexNameLease = "lease"
// IndexNameRequestPath is the request path of the index.
IndexNameRequestPath = "request_path"
// IndexNameToken is the token of the index.
IndexNameToken = "token"
// IndexNameTokenAccessor is the token accessor of the index.
IndexNameTokenAccessor = "token_accessor"
// IndexNameTokenParent is the token parent of the index.
IndexNameTokenParent = "token_parent"
// IndexNameLeaseToken is the token that created the lease.
IndexNameLeaseToken = "lease_token"
)
func validIndexName(indexName string) bool {
switch indexName {
case "id":
case "lease":
case "request_path":
case "token":
case "token_accessor":
case "token_parent":
case "lease_token":
default:
return false
}
return true
}

155
command/agent/cache/handler.go vendored Normal file
View File

@ -0,0 +1,155 @@
package cache
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/hashicorp/errwrap"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/consts"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
)
func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, useAutoAuthToken bool, client *api.Client) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Info("received request", "path", r.URL.Path, "method", r.Method)
token := r.Header.Get(consts.AuthHeaderName)
if token == "" && useAutoAuthToken {
logger.Debug("using auto auth token")
token = client.Token()
}
// Parse and reset body.
reqBody, err := ioutil.ReadAll(r.Body)
if err != nil {
logger.Error("failed to read request body")
respondError(w, http.StatusInternalServerError, errors.New("failed to read request body"))
}
if r.Body != nil {
r.Body.Close()
}
r.Body = ioutil.NopCloser(bytes.NewBuffer(reqBody))
req := &SendRequest{
Token: token,
Request: r,
RequestBody: reqBody,
}
resp, err := proxier.Send(ctx, req)
if err != nil {
respondError(w, http.StatusInternalServerError, errwrap.Wrapf("failed to get the response: {{err}}", err))
return
}
err = processTokenLookupResponse(ctx, logger, useAutoAuthToken, client, req, resp)
if err != nil {
respondError(w, http.StatusInternalServerError, errwrap.Wrapf("failed to process token lookup response: {{err}}", err))
return
}
defer resp.Response.Body.Close()
copyHeader(w.Header(), resp.Response.Header)
w.WriteHeader(resp.Response.StatusCode)
io.Copy(w, resp.Response.Body)
return
})
}
// processTokenLookupResponse checks if the request was one of token
// lookup-self. If the auto-auth token was used to perform lookup-self, the
// identifier of the token and its accessor same will be stripped off of the
// response.
func processTokenLookupResponse(ctx context.Context, logger hclog.Logger, useAutoAuthToken bool, client *api.Client, req *SendRequest, resp *SendResponse) error {
// If auto-auth token is not being used, there is nothing to do.
if !useAutoAuthToken {
return nil
}
// If lookup responded with non 200 status, there is nothing to do.
if resp.Response.StatusCode != http.StatusOK {
return nil
}
// Strip-off namespace related information from the request and get the
// relative path of the request.
_, path := deriveNamespaceAndRevocationPath(req)
if path == vaultPathTokenLookupSelf {
logger.Info("stripping auto-auth token from the response", "path", req.Request.URL.Path, "method", req.Request.Method)
secret, err := api.ParseSecret(bytes.NewBuffer(resp.ResponseBody))
if err != nil {
return fmt.Errorf("failed to parse token lookup response: %v", err)
}
if secret != nil && secret.Data != nil && secret.Data["id"] != nil {
token, ok := secret.Data["id"].(string)
if !ok {
return fmt.Errorf("failed to type assert the token id in the response")
}
if token == client.Token() {
delete(secret.Data, "id")
delete(secret.Data, "accessor")
}
bodyBytes, err := json.Marshal(secret)
if err != nil {
return err
}
if resp.Response.Body != nil {
resp.Response.Body.Close()
}
resp.Response.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
resp.Response.ContentLength = int64(len(bodyBytes))
// Serialize and re-read the reponse
var respBytes bytes.Buffer
err = resp.Response.Write(&respBytes)
if err != nil {
return fmt.Errorf("failed to serialize the updated response: %v", err)
}
updatedResponse, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBytes.Bytes())), nil)
if err != nil {
return fmt.Errorf("failed to deserialize the updated response: %v", err)
}
resp.Response = &api.Response{
Response: updatedResponse,
}
resp.ResponseBody = bodyBytes
}
}
return nil
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func respondError(w http.ResponseWriter, status int, err error) {
logical.AdjustErrorStatusCode(&status, err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
resp := &vaulthttp.ErrorResponse{Errors: make([]string, 0, 1)}
if err != nil {
resp.Errors = append(resp.Errors, err.Error())
}
enc := json.NewEncoder(w)
enc.Encode(resp)
}

813
command/agent/cache/lease_cache.go vendored Normal file
View File

@ -0,0 +1,813 @@
package cache
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/hashicorp/errwrap"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/namespace"
nshelper "github.com/hashicorp/vault/helper/namespace"
)
const (
vaultPathTokenCreate = "/v1/auth/token/create"
vaultPathTokenRevoke = "/v1/auth/token/revoke"
vaultPathTokenRevokeSelf = "/v1/auth/token/revoke-self"
vaultPathTokenRevokeAccessor = "/v1/auth/token/revoke-accessor"
vaultPathTokenRevokeOrphan = "/v1/auth/token/revoke-orphan"
vaultPathTokenLookupSelf = "/v1/auth/token/lookup-self"
vaultPathLeaseRevoke = "/v1/sys/leases/revoke"
vaultPathLeaseRevokeForce = "/v1/sys/leases/revoke-force"
vaultPathLeaseRevokePrefix = "/v1/sys/leases/revoke-prefix"
)
var (
contextIndexID = contextIndex{}
errInvalidType = errors.New("invalid type provided")
revocationPaths = []string{
strings.TrimPrefix(vaultPathTokenRevoke, "/v1"),
strings.TrimPrefix(vaultPathTokenRevokeSelf, "/v1"),
strings.TrimPrefix(vaultPathTokenRevokeAccessor, "/v1"),
strings.TrimPrefix(vaultPathTokenRevokeOrphan, "/v1"),
strings.TrimPrefix(vaultPathLeaseRevoke, "/v1"),
strings.TrimPrefix(vaultPathLeaseRevokeForce, "/v1"),
strings.TrimPrefix(vaultPathLeaseRevokePrefix, "/v1"),
}
)
type contextIndex struct{}
type cacheClearRequest struct {
Type string `json:"type"`
Value string `json:"value"`
Namespace string `json:"namespace"`
}
// LeaseCache is an implementation of Proxier that handles
// the caching of responses. It passes the incoming request
// to an underlying Proxier implementation.
type LeaseCache struct {
proxier Proxier
logger hclog.Logger
db *cachememdb.CacheMemDB
baseCtxInfo *ContextInfo
}
// LeaseCacheConfig is the configuration for initializing a new
// Lease.
type LeaseCacheConfig struct {
BaseContext context.Context
Proxier Proxier
Logger hclog.Logger
}
// ContextInfo holds a derived context and cancelFunc pair.
type ContextInfo struct {
Ctx context.Context
CancelFunc context.CancelFunc
DoneCh chan struct{}
}
// NewLeaseCache creates a new instance of a LeaseCache.
func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
if conf == nil {
return nil, errors.New("nil configuration provided")
}
if conf.Proxier == nil || conf.Logger == nil {
return nil, fmt.Errorf("missing configuration required params: %v", conf)
}
db, err := cachememdb.New()
if err != nil {
return nil, err
}
// Create a base context for the lease cache layer
baseCtx, baseCancelFunc := context.WithCancel(conf.BaseContext)
baseCtxInfo := &ContextInfo{
Ctx: baseCtx,
CancelFunc: baseCancelFunc,
}
return &LeaseCache{
proxier: conf.Proxier,
logger: conf.Logger,
db: db,
baseCtxInfo: baseCtxInfo,
}, nil
}
// Send performs a cache lookup on the incoming request. If it's a cache hit,
// it will return the cached response, otherwise it will delegate to the
// underlying Proxier and cache the received response.
func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
// Compute the index ID
id, err := computeIndexID(req)
if err != nil {
c.logger.Error("failed to compute cache key", "error", err)
return nil, err
}
// Check if the response for this request is already in the cache
index, err := c.db.Get(cachememdb.IndexNameID, id)
if err != nil {
return nil, err
}
// Cached request is found, deserialize the response and return early
if index != nil {
c.logger.Debug("returning cached response", "path", req.Request.URL.Path)
reader := bufio.NewReader(bytes.NewReader(index.Response))
resp, err := http.ReadResponse(reader, nil)
if err != nil {
c.logger.Error("failed to deserialize response", "error", err)
return nil, err
}
return &SendResponse{
Response: &api.Response{
Response: resp,
},
ResponseBody: index.Response,
}, nil
}
c.logger.Debug("forwarding request", "path", req.Request.URL.Path, "method", req.Request.Method)
// Pass the request down and get a response
resp, err := c.proxier.Send(ctx, req)
if err != nil {
return nil, err
}
// Get the namespace from the request header
namespace := req.Request.Header.Get(consts.NamespaceHeaderName)
// We need to populate an empty value since go-memdb will skip over indexes
// that contain empty values.
if namespace == "" {
namespace = "root/"
}
// Build the index to cache based on the response received
index = &cachememdb.Index{
ID: id,
Namespace: namespace,
RequestPath: req.Request.URL.Path,
}
secret, err := api.ParseSecret(bytes.NewBuffer(resp.ResponseBody))
if err != nil {
c.logger.Error("failed to parse response as secret", "error", err)
return nil, err
}
isRevocation, err := c.handleRevocationRequest(ctx, req, resp)
if err != nil {
c.logger.Error("failed to process the response", "error", err)
return nil, err
}
// If this is a revocation request, do not go through cache logic.
if isRevocation {
return resp, nil
}
// Fast path for responses with no secrets
if secret == nil {
c.logger.Debug("pass-through response; no secret in response", "path", req.Request.URL.Path, "method", req.Request.Method)
return resp, nil
}
// Short-circuit if the secret is not renewable
tokenRenewable, err := secret.TokenIsRenewable()
if err != nil {
c.logger.Error("failed to parse renewable param", "error", err)
return nil, err
}
if !secret.Renewable && !tokenRenewable {
c.logger.Debug("pass-through response; secret not renewable", "path", req.Request.URL.Path, "method", req.Request.Method)
return resp, nil
}
var renewCtxInfo *ContextInfo
switch {
case secret.LeaseID != "":
c.logger.Debug("processing lease response", "path", req.Request.URL.Path, "method", req.Request.Method)
entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token)
if err != nil {
return nil, err
}
// If the lease belongs to a token that is not managed by the agent,
// return the response without caching it.
if entry == nil {
c.logger.Debug("pass-through lease response; token not managed by agent", "path", req.Request.URL.Path, "method", req.Request.Method)
return resp, nil
}
// Derive a context for renewal using the token's context
newCtxInfo := new(ContextInfo)
newCtxInfo.Ctx, newCtxInfo.CancelFunc = context.WithCancel(entry.RenewCtxInfo.Ctx)
newCtxInfo.DoneCh = make(chan struct{})
renewCtxInfo = newCtxInfo
index.Lease = secret.LeaseID
index.LeaseToken = req.Token
case secret.Auth != nil:
c.logger.Debug("processing auth response", "path", req.Request.URL.Path, "method", req.Request.Method)
isNonOrphanNewToken := strings.HasPrefix(req.Request.URL.Path, vaultPathTokenCreate) && resp.Response.StatusCode == http.StatusOK && !secret.Auth.Orphan
// If the new token is a result of token creation endpoints (not from
// login endpoints), and if its a non-orphan, then the new token's
// context should be derived from the context of the parent token.
var parentCtx context.Context
if isNonOrphanNewToken {
entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token)
if err != nil {
return nil, err
}
// If parent token is not managed by the agent, child shouldn't be
// either.
if entry == nil {
c.logger.Debug("pass-through auth response; parent token not managed by agent", "path", req.Request.URL.Path, "method", req.Request.Method)
return resp, nil
}
c.logger.Debug("setting parent context", "path", req.Request.URL.Path, "method", req.Request.Method)
parentCtx = entry.RenewCtxInfo.Ctx
entry.TokenParent = req.Token
}
renewCtxInfo = c.createCtxInfo(parentCtx, secret.Auth.ClientToken)
index.Token = secret.Auth.ClientToken
index.TokenAccessor = secret.Auth.Accessor
default:
// We shouldn't be hitting this, but will err on the side of caution and
// simply proxy.
c.logger.Debug("pass-through response; secret without lease and token", "path", req.Request.URL.Path, "method", req.Request.Method)
return resp, nil
}
// Serialize the response to store it in the cached index
var respBytes bytes.Buffer
err = resp.Response.Write(&respBytes)
if err != nil {
c.logger.Error("failed to serialize response", "error", err)
return nil, err
}
// Reset the response body for upper layers to read
if resp.Response.Body != nil {
resp.Response.Body.Close()
}
resp.Response.Body = ioutil.NopCloser(bytes.NewBuffer(resp.ResponseBody))
// Set the index's Response
index.Response = respBytes.Bytes()
// Store the index ID in the renewer context
renewCtx := context.WithValue(renewCtxInfo.Ctx, contextIndexID, index.ID)
// Store the renewer context in the index
index.RenewCtxInfo = &cachememdb.ContextInfo{
Ctx: renewCtx,
CancelFunc: renewCtxInfo.CancelFunc,
DoneCh: renewCtxInfo.DoneCh,
}
// Store the index in the cache
c.logger.Debug("storing response into the cache", "path", req.Request.URL.Path, "method", req.Request.Method)
err = c.db.Set(index)
if err != nil {
c.logger.Error("failed to cache the proxied response", "error", err)
return nil, err
}
// Start renewing the secret in the response
go c.startRenewing(renewCtx, index, req, secret)
return resp, nil
}
func (c *LeaseCache) createCtxInfo(ctx context.Context, token string) *ContextInfo {
if ctx == nil {
ctx = c.baseCtxInfo.Ctx
}
ctxInfo := new(ContextInfo)
ctxInfo.Ctx, ctxInfo.CancelFunc = context.WithCancel(ctx)
ctxInfo.DoneCh = make(chan struct{})
return ctxInfo
}
func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, req *SendRequest, secret *api.Secret) {
defer func() {
id := ctx.Value(contextIndexID).(string)
c.logger.Debug("evicting index from cache", "id", id, "path", req.Request.URL.Path, "method", req.Request.Method)
err := c.db.Evict(cachememdb.IndexNameID, id)
if err != nil {
c.logger.Error("failed to evict index", "id", id, "error", err)
return
}
}()
client, err := api.NewClient(api.DefaultConfig())
if err != nil {
c.logger.Error("failed to create API client in the renewer", "error", err)
return
}
client.SetToken(req.Token)
client.SetHeaders(req.Request.Header)
renewer, err := client.NewRenewer(&api.RenewerInput{
Secret: secret,
})
if err != nil {
c.logger.Error("failed to create secret renewer", "error", err)
return
}
c.logger.Debug("initiating renewal", "path", req.Request.URL.Path, "method", req.Request.Method)
go renewer.Renew()
defer renewer.Stop()
for {
select {
case <-ctx.Done():
// This is the case which captures context cancellations from token
// and leases. Since all the contexts are derived from the agent's
// context, this will also cover the shutdown scenario.
c.logger.Debug("context cancelled; stopping renewer", "path", req.Request.URL.Path)
return
case err := <-renewer.DoneCh():
// This case covers renewal completion and renewal errors
if err != nil {
c.logger.Error("failed to renew secret", "error", err)
return
}
c.logger.Debug("renewal halted; evicting from cache", "path", req.Request.URL.Path)
return
case renewal := <-renewer.RenewCh():
// This case captures secret renewals. Renewed secret is updated in
// the cached index.
c.logger.Debug("renewal received; updating cache", "path", req.Request.URL.Path)
err = c.updateResponse(ctx, renewal)
if err != nil {
c.logger.Error("failed to handle renewal", "error", err)
return
}
case <-index.RenewCtxInfo.DoneCh:
// This case indicates the renewal process to shutdown and evict
// the cache entry. This is triggered when a specific secret
// renewal needs to be killed without affecting any of the derived
// context renewals.
c.logger.Debug("done channel closed")
return
}
}
}
func (c *LeaseCache) updateResponse(ctx context.Context, renewal *api.RenewOutput) error {
id := ctx.Value(contextIndexID).(string)
// Get the cached index using the id in the context
index, err := c.db.Get(cachememdb.IndexNameID, id)
if err != nil {
return err
}
if index == nil {
return fmt.Errorf("missing cache entry for id: %q", id)
}
// Read the response from the index
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(index.Response)), nil)
if err != nil {
c.logger.Error("failed to deserialize response", "error", err)
return err
}
// Update the body in the reponse by the renewed secret
bodyBytes, err := json.Marshal(renewal.Secret)
if err != nil {
return err
}
if resp.Body != nil {
resp.Body.Close()
}
resp.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
resp.ContentLength = int64(len(bodyBytes))
// Serialize the response
var respBytes bytes.Buffer
err = resp.Write(&respBytes)
if err != nil {
c.logger.Error("failed to serialize updated response", "error", err)
return err
}
// Update the response in the index and set it in the cache
index.Response = respBytes.Bytes()
err = c.db.Set(index)
if err != nil {
c.logger.Error("failed to cache the proxied response", "error", err)
return err
}
return nil
}
// computeIndexID results in a value that uniquely identifies a request
// received by the agent. It does so by SHA256 hashing the serialized request
// object containing the request path, query parameters and body parameters.
func computeIndexID(req *SendRequest) (string, error) {
var b bytes.Buffer
// Serialze the request
if err := req.Request.Write(&b); err != nil {
return "", fmt.Errorf("failed to serialize request: %v", err)
}
// Reset the request body after it has been closed by Write
if req.Request.Body != nil {
req.Request.Body.Close()
}
req.Request.Body = ioutil.NopCloser(bytes.NewBuffer(req.RequestBody))
// Append req.Token into the byte slice. This is needed since auto-auth'ed
// requests sets the token directly into SendRequest.Token
b.Write([]byte(req.Token))
sum := sha256.Sum256(b.Bytes())
return hex.EncodeToString(sum[:]), nil
}
// HandleCacheClear returns a handlerFunc that can perform cache clearing operations.
func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req := new(cacheClearRequest)
if err := jsonutil.DecodeJSONFromReader(r.Body, req); err != nil {
if err == io.EOF {
err = errors.New("empty JSON provided")
}
respondError(w, http.StatusBadRequest, errwrap.Wrapf("failed to parse JSON input: {{err}}", err))
return
}
c.logger.Debug("received cache-clear request", "type", req.Type, "namespace", req.Namespace, "value", req.Value)
if err := c.handleCacheClear(ctx, req.Type, req.Namespace, req.Value); err != nil {
// Default to 500 on error, unless the user provided an invalid type,
// which would then be a 400.
httpStatus := http.StatusInternalServerError
if err == errInvalidType {
httpStatus = http.StatusBadRequest
}
respondError(w, httpStatus, errwrap.Wrapf("failed to clear cache: {{err}}", err))
return
}
return
})
}
func (c *LeaseCache) handleCacheClear(ctx context.Context, clearType string, clearValues ...interface{}) error {
if len(clearValues) == 0 {
return errors.New("no value(s) provided to clear corresponding cache entries")
}
// The value that we want to clear, for most cases, is the last one provided.
clearValue, ok := clearValues[len(clearValues)-1].(string)
if !ok {
return fmt.Errorf("unable to convert %v to type string", clearValue)
}
switch clearType {
case "request_path":
// For this particular case, we need to ensure that there are 2 provided
// indexers for the proper lookup.
if len(clearValues) != 2 {
return fmt.Errorf("clearing cache by request path requires 2 indexers, got %d", len(clearValues))
}
// The first value provided for this case will be the namespace, but if it's
// an empty value we need to overwrite it with "root/" to ensure proper
// cache lookup.
if clearValues[0].(string) == "" {
clearValues[0] = "root/"
}
// Find all the cached entries which has the given request path and
// cancel the contexts of all the respective renewers
indexes, err := c.db.GetByPrefix(clearType, clearValues...)
if err != nil {
return err
}
for _, index := range indexes {
index.RenewCtxInfo.CancelFunc()
}
case "token":
if clearValue == "" {
return nil
}
// Get the context for the given token and cancel its context
index, err := c.db.Get(cachememdb.IndexNameToken, clearValue)
if err != nil {
return err
}
if index == nil {
return nil
}
c.logger.Debug("cancelling context of index attached to token")
index.RenewCtxInfo.CancelFunc()
case "token_accessor", "lease":
// Get the cached index and cancel the corresponding renewer context
index, err := c.db.Get(clearType, clearValue)
if err != nil {
return err
}
if index == nil {
return nil
}
c.logger.Debug("cancelling context of index attached to accessor")
index.RenewCtxInfo.CancelFunc()
case "all":
// Cancel the base context which triggers all the goroutines to
// stop and evict entries from cache.
c.logger.Debug("cancelling base context")
c.baseCtxInfo.CancelFunc()
// Reset the base context
baseCtx, baseCancel := context.WithCancel(ctx)
c.baseCtxInfo = &ContextInfo{
Ctx: baseCtx,
CancelFunc: baseCancel,
}
// Reset the memdb instance
if err := c.db.Flush(); err != nil {
return err
}
default:
return errInvalidType
}
c.logger.Debug("successfully cleared matching cache entries")
return nil
}
// handleRevocationRequest checks whether the originating request is a
// revocation request, and if so perform applicable cache cleanups.
// Returns true is this is a revocation request.
func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendRequest, resp *SendResponse) (bool, error) {
// Lease and token revocations return 204's on success. Fast-path if that's
// not the case.
if resp.Response.StatusCode != http.StatusNoContent {
return false, nil
}
_, path := deriveNamespaceAndRevocationPath(req)
switch {
case path == vaultPathTokenRevoke:
// Get the token from the request body
jsonBody := map[string]interface{}{}
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
return false, err
}
tokenRaw, ok := jsonBody["token"]
if !ok {
return false, fmt.Errorf("failed to get token from request body")
}
token, ok := tokenRaw.(string)
if !ok {
return false, fmt.Errorf("expected token in the request body to be string")
}
// Clear the cache entry associated with the token and all the other
// entries belonging to the leases derived from this token.
if err := c.handleCacheClear(ctx, "token", token); err != nil {
return false, err
}
case path == vaultPathTokenRevokeSelf:
// Clear the cache entry associated with the token and all the other
// entries belonging to the leases derived from this token.
if err := c.handleCacheClear(ctx, "token", req.Token); err != nil {
return false, err
}
case path == vaultPathTokenRevokeAccessor:
jsonBody := map[string]interface{}{}
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
return false, err
}
accessorRaw, ok := jsonBody["accessor"]
if !ok {
return false, fmt.Errorf("failed to get accessor from request body")
}
accessor, ok := accessorRaw.(string)
if !ok {
return false, fmt.Errorf("expected accessor in the request body to be string")
}
if err := c.handleCacheClear(ctx, "token_accessor", accessor); err != nil {
return false, err
}
case path == vaultPathTokenRevokeOrphan:
jsonBody := map[string]interface{}{}
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
return false, err
}
tokenRaw, ok := jsonBody["token"]
if !ok {
return false, fmt.Errorf("failed to get token from request body")
}
token, ok := tokenRaw.(string)
if !ok {
return false, fmt.Errorf("expected token in the request body to be string")
}
// Kill the renewers of all the leases attached to the revoked token
indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLeaseToken, token)
if err != nil {
return false, err
}
for _, index := range indexes {
index.RenewCtxInfo.CancelFunc()
}
// Kill the renewer of the revoked token
index, err := c.db.Get(cachememdb.IndexNameToken, token)
if err != nil {
return false, err
}
if index == nil {
return true, nil
}
// Indicate the renewer goroutine for this index to return. This will
// not affect the child tokens because the context is not getting
// cancelled.
close(index.RenewCtxInfo.DoneCh)
// Clear the parent references of the revoked token in the entries
// belonging to the child tokens of the revoked token.
indexes, err = c.db.GetByPrefix(cachememdb.IndexNameTokenParent, token)
if err != nil {
return false, err
}
for _, index := range indexes {
index.TokenParent = ""
err = c.db.Set(index)
if err != nil {
c.logger.Error("failed to persist index", "error", err)
return false, err
}
}
case path == vaultPathLeaseRevoke:
// TODO: Should lease present in the URL itself be considered here?
// Get the lease from the request body
jsonBody := map[string]interface{}{}
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
return false, err
}
leaseIDRaw, ok := jsonBody["lease_id"]
if !ok {
return false, fmt.Errorf("failed to get lease_id from request body")
}
leaseID, ok := leaseIDRaw.(string)
if !ok {
return false, fmt.Errorf("expected lease_id the request body to be string")
}
if err := c.handleCacheClear(ctx, "lease", leaseID); err != nil {
return false, err
}
case strings.HasPrefix(path, vaultPathLeaseRevokeForce):
// Trim the URL path to get the request path prefix
prefix := strings.TrimPrefix(path, vaultPathLeaseRevokeForce)
// Get all the cache indexes that use the request path containing the
// prefix and cancel the renewer context of each.
indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLease, prefix)
if err != nil {
return false, err
}
_, tokenNSID := namespace.SplitIDFromString(req.Token)
for _, index := range indexes {
_, leaseNSID := namespace.SplitIDFromString(index.Lease)
// Only evict leases that match the token's namespace
if tokenNSID == leaseNSID {
index.RenewCtxInfo.CancelFunc()
}
}
case strings.HasPrefix(path, vaultPathLeaseRevokePrefix):
// Trim the URL path to get the request path prefix
prefix := strings.TrimPrefix(path, vaultPathLeaseRevokePrefix)
// Get all the cache indexes that use the request path containing the
// prefix and cancel the renewer context of each.
indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLease, prefix)
if err != nil {
return false, err
}
_, tokenNSID := namespace.SplitIDFromString(req.Token)
for _, index := range indexes {
_, leaseNSID := namespace.SplitIDFromString(index.Lease)
// Only evict leases that match the token's namespace
if tokenNSID == leaseNSID {
index.RenewCtxInfo.CancelFunc()
}
}
default:
return false, nil
}
c.logger.Debug("triggered caching eviction from revocation request")
return true, nil
}
// deriveNamespaceAndRevocationPath returns the namespace and relative path for
// revocation paths.
//
// If the path contains a namespace, but it's not a revocation path, it will be
// returned as-is, since there's no way to tell where the namespace ends and
// where the request path begins purely based off a string.
//
// Case 1: /v1/ns1/leases/revoke -> ns1/, /v1/leases/revoke
// Case 2: ns1/ /v1/leases/revoke -> ns1/, /v1/leases/revoke
// Case 3: /v1/ns1/foo/bar -> root/, /v1/ns1/foo/bar
// Case 4: ns1/ /v1/foo/bar -> ns1/, /v1/foo/bar
func deriveNamespaceAndRevocationPath(req *SendRequest) (string, string) {
namespace := "root/"
nsHeader := req.Request.Header.Get(consts.NamespaceHeaderName)
if nsHeader != "" {
namespace = nsHeader
}
fullPath := req.Request.URL.Path
nonVersionedPath := strings.TrimPrefix(fullPath, "/v1")
for _, pathToCheck := range revocationPaths {
// We use strings.Contains here for paths that can contain
// vars in the path, e.g. /v1/lease/revoke-prefix/:prefix
i := strings.Index(nonVersionedPath, pathToCheck)
// If there's no match, move on to the next check
if i == -1 {
continue
}
// If the index is 0, this is a relative path with no namespace preppended,
// so we can break early
if i == 0 {
break
}
// We need to turn /ns1 into ns1/, this makes it easy
namespaceInPath := nshelper.Canonicalize(nonVersionedPath[:i])
// If it's root, we replace, otherwise we join
if namespace == "root/" {
namespace = namespaceInPath
} else {
namespace = namespace + namespaceInPath
}
return namespace, fmt.Sprintf("/v1%s", nonVersionedPath[i:])
}
return namespace, fmt.Sprintf("/v1%s", nonVersionedPath)
}

507
command/agent/cache/lease_cache_test.go vendored Normal file
View File

@ -0,0 +1,507 @@
package cache
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"github.com/go-test/deep"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/logging"
)
func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
t.Helper()
lc, err := NewLeaseCache(&LeaseCacheConfig{
BaseContext: context.Background(),
Proxier: newMockProxier(responses),
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
})
if err != nil {
t.Fatal(err)
}
return lc
}
func TestCache_ComputeIndexID(t *testing.T) {
type args struct {
req *http.Request
}
tests := []struct {
name string
req *SendRequest
want string
wantErr bool
}{
{
"basic",
&SendRequest{
Request: &http.Request{
URL: &url.URL{
Path: "test",
},
},
},
"2edc7e965c3e1bdce3b1d5f79a52927842569c0734a86544d222753f11ae4847",
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := computeIndexID(tt.req)
if (err != nil) != tt.wantErr {
t.Errorf("actual_error: %v, expected_error: %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, string(tt.want)) {
t.Errorf("bad: index id; actual: %q, expected: %q", got, string(tt.want))
}
})
}
}
func TestCache_LeaseCache_EmptyToken(t *testing.T) {
responses := []*SendResponse{
&SendResponse{
Response: &api.Response{
Response: &http.Response{
StatusCode: http.StatusCreated,
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`)),
},
},
ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`),
},
}
lc := testNewLeaseCache(t, responses)
// Even if the send request doesn't have a token on it, a successful
// cacheable response should result in the index properly getting populated
// with a token and memdb shouldn't complain while inserting the index.
urlPath := "http://example.com/v1/sample/api"
sendReq := &SendRequest{
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
}
resp, err := lc.Send(context.Background(), sendReq)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatalf("expected a non empty response")
}
}
func TestCache_LeaseCache_SendCacheable(t *testing.T) {
// Emulate 2 responses from the api proxy. One returns a new token and the
// other returns a lease.
responses := []*SendResponse{
&SendResponse{
Response: &api.Response{
Response: &http.Response{
StatusCode: http.StatusCreated,
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken", "renewable": true}}`)),
},
},
ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken", "renewable": true}}`),
},
&SendResponse{
Response: &api.Response{
Response: &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo", "renewable": true}`)),
},
},
ResponseBody: []byte(`{"value": "output", "lease_id": "foo", "renewable": true}`),
},
}
lc := testNewLeaseCache(t, responses)
// Make a request. A response with a new token is returned to the lease
// cache and that will be cached.
urlPath := "http://example.com/v1/sample/api"
sendReq := &SendRequest{
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
}
resp, err := lc.Send(context.Background(), sendReq)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
t.Fatalf("expected getting proxied response: got %v", diff)
}
// Send the same request again to get the cached response
sendReq = &SendRequest{
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
}
resp, err = lc.Send(context.Background(), sendReq)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
t.Fatalf("expected getting proxied response: got %v", diff)
}
// Modify the request a little bit to ensure the second response is
// returned to the lease cache. But make sure that the token in the request
// is valid.
sendReq = &SendRequest{
Token: "testtoken",
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)),
}
resp, err = lc.Send(context.Background(), sendReq)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
t.Fatalf("expected getting proxied response: got %v", diff)
}
// Make the same request again and ensure that the same reponse is returned
// again.
sendReq = &SendRequest{
Token: "testtoken",
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)),
}
resp, err = lc.Send(context.Background(), sendReq)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
t.Fatalf("expected getting proxied response: got %v", diff)
}
}
func TestCache_LeaseCache_SendNonCacheable(t *testing.T) {
responses := []*SendResponse{
&SendResponse{
Response: &api.Response{
Response: &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(strings.NewReader(`{"value": "output"}`)),
},
},
},
&SendResponse{
Response: &api.Response{
Response: &http.Response{
StatusCode: http.StatusNotFound,
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid"}`)),
},
},
},
}
lc := testNewLeaseCache(t, responses)
// Send a request through the lease cache which is not cacheable (there is
// no lease information or auth information in the response)
sendReq := &SendRequest{
Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)),
}
resp, err := lc.Send(context.Background(), sendReq)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
t.Fatalf("expected getting proxied response: got %v", diff)
}
// Since the response is non-cacheable, the second response will be
// returned.
sendReq = &SendRequest{
Token: "foo",
Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)),
}
resp, err = lc.Send(context.Background(), sendReq)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
t.Fatalf("expected getting proxied response: got %v", diff)
}
}
func TestCache_LeaseCache_SendNonCacheableNonTokenLease(t *testing.T) {
// Create the cache
responses := []*SendResponse{
&SendResponse{
Response: &api.Response{
Response: &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo"}`)),
},
},
ResponseBody: []byte(`{"value": "output", "lease_id": "foo"}`),
},
&SendResponse{
Response: &api.Response{
Response: &http.Response{
StatusCode: http.StatusCreated,
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`)),
},
},
ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`),
},
}
lc := testNewLeaseCache(t, responses)
// Send a request through lease cache which returns a response containing
// lease_id. Response will not be cached because it doesn't belong to a
// token that is managed by the lease cache.
urlPath := "http://example.com/v1/sample/api"
sendReq := &SendRequest{
Token: "foo",
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
}
resp, err := lc.Send(context.Background(), sendReq)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
t.Fatalf("expected getting proxied response: got %v", diff)
}
// Verify that the response is not cached by sending the same request and
// by expecting a different response.
sendReq = &SendRequest{
Token: "foo",
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
}
resp, err = lc.Send(context.Background(), sendReq)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff == nil {
t.Fatalf("expected getting proxied response: got %v", diff)
}
}
func TestCache_LeaseCache_HandleCacheClear(t *testing.T) {
lc := testNewLeaseCache(t, nil)
handler := lc.HandleCacheClear(context.Background())
ts := httptest.NewServer(handler)
defer ts.Close()
// Test missing body, should return 400
resp, err := http.Post(ts.URL, "application/json", nil)
if err != nil {
t.Fatal()
}
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("status code mismatch: expected = %v, got = %v", http.StatusBadRequest, resp.StatusCode)
}
testCases := []struct {
name string
reqType string
reqValue string
expectedStatusCode int
}{
{
"invalid_type",
"foo",
"",
http.StatusBadRequest,
},
{
"invalid_value",
"",
"bar",
http.StatusBadRequest,
},
{
"all",
"all",
"",
http.StatusOK,
},
{
"by_request_path",
"request_path",
"foo",
http.StatusOK,
},
{
"by_token",
"token",
"foo",
http.StatusOK,
},
{
"by_lease",
"lease",
"foo",
http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
reqBody := fmt.Sprintf("{\"type\": \"%s\", \"value\": \"%s\"}", tc.reqType, tc.reqValue)
resp, err := http.Post(ts.URL, "application/json", strings.NewReader(reqBody))
if err != nil {
t.Fatal(err)
}
if tc.expectedStatusCode != resp.StatusCode {
t.Fatalf("status code mismatch: expected = %v, got = %v", tc.expectedStatusCode, resp.StatusCode)
}
})
}
}
func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) {
tests := []struct {
name string
req *SendRequest
wantNamespace string
wantRelativePath string
}{
{
"non_revocation_full_path",
&SendRequest{
Request: &http.Request{
URL: &url.URL{
Path: "/v1/ns1/sys/mounts",
},
},
},
"root/",
"/v1/ns1/sys/mounts",
},
{
"non_revocation_relative_path",
&SendRequest{
Request: &http.Request{
URL: &url.URL{
Path: "/v1/sys/mounts",
},
Header: http.Header{
consts.NamespaceHeaderName: []string{"ns1/"},
},
},
},
"ns1/",
"/v1/sys/mounts",
},
{
"non_revocation_relative_path",
&SendRequest{
Request: &http.Request{
URL: &url.URL{
Path: "/v1/ns2/sys/mounts",
},
Header: http.Header{
consts.NamespaceHeaderName: []string{"ns1/"},
},
},
},
"ns1/",
"/v1/ns2/sys/mounts",
},
{
"revocation_full_path",
&SendRequest{
Request: &http.Request{
URL: &url.URL{
Path: "/v1/ns1/sys/leases/revoke",
},
},
},
"ns1/",
"/v1/sys/leases/revoke",
},
{
"revocation_relative_path",
&SendRequest{
Request: &http.Request{
URL: &url.URL{
Path: "/v1/sys/leases/revoke",
},
Header: http.Header{
consts.NamespaceHeaderName: []string{"ns1/"},
},
},
},
"ns1/",
"/v1/sys/leases/revoke",
},
{
"revocation_relative_partial_ns",
&SendRequest{
Request: &http.Request{
URL: &url.URL{
Path: "/v1/ns2/sys/leases/revoke",
},
Header: http.Header{
consts.NamespaceHeaderName: []string{"ns1/"},
},
},
},
"ns1/ns2/",
"/v1/sys/leases/revoke",
},
{
"revocation_prefix_full_path",
&SendRequest{
Request: &http.Request{
URL: &url.URL{
Path: "/v1/ns1/sys/leases/revoke-prefix/foo",
},
},
},
"ns1/",
"/v1/sys/leases/revoke-prefix/foo",
},
{
"revocation_prefix_relative_path",
&SendRequest{
Request: &http.Request{
URL: &url.URL{
Path: "/v1/sys/leases/revoke-prefix/foo",
},
Header: http.Header{
consts.NamespaceHeaderName: []string{"ns1/"},
},
},
},
"ns1/",
"/v1/sys/leases/revoke-prefix/foo",
},
{
"revocation_prefix_partial_ns",
&SendRequest{
Request: &http.Request{
URL: &url.URL{
Path: "/v1/ns2/sys/leases/revoke-prefix/foo",
},
Header: http.Header{
consts.NamespaceHeaderName: []string{"ns1/"},
},
},
},
"ns1/ns2/",
"/v1/sys/leases/revoke-prefix/foo",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotNamespace, gotRelativePath := deriveNamespaceAndRevocationPath(tt.req)
if gotNamespace != tt.wantNamespace {
t.Errorf("deriveNamespaceAndRevocationPath() gotNamespace = %v, want %v", gotNamespace, tt.wantNamespace)
}
if gotRelativePath != tt.wantRelativePath {
t.Errorf("deriveNamespaceAndRevocationPath() gotRelativePath = %v, want %v", gotRelativePath, tt.wantRelativePath)
}
})
}
}

105
command/agent/cache/listener.go vendored Normal file
View File

@ -0,0 +1,105 @@
package cache
import (
"fmt"
"io"
"net"
"os"
"strings"
"github.com/hashicorp/vault/command/agent/config"
"github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/helper/reload"
"github.com/mitchellh/cli"
)
func ServerListener(lnConfig *config.Listener, logger io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
switch lnConfig.Type {
case "unix":
return unixSocketListener(lnConfig.Config, logger, ui)
case "tcp":
return tcpListener(lnConfig.Config, logger, ui)
default:
return nil, nil, nil, fmt.Errorf("unsupported listener type: %q", lnConfig.Type)
}
}
func unixSocketListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
addr, ok := config["address"].(string)
if !ok {
return nil, nil, nil, fmt.Errorf("invalid address: %v", config["address"])
}
if addr == "" {
return nil, nil, nil, fmt.Errorf("address field should point to socket file path")
}
// Remove the socket file as it shouldn't exist for the domain socket to
// work
err := os.Remove(addr)
if err != nil && !os.IsNotExist(err) {
return nil, nil, nil, fmt.Errorf("failed to remove the socket file: %v", err)
}
listener, err := net.Listen("unix", addr)
if err != nil {
return nil, nil, nil, err
}
// Wrap the listener in rmListener so that the Unix domain socket file is
// removed on close.
listener = &rmListener{
Listener: listener,
Path: addr,
}
props := map[string]string{"addr": addr, "tls": "disabled"}
return listener, props, nil, nil
}
func tcpListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
bindProto := "tcp"
var addr string
addrRaw, ok := config["address"]
if !ok {
addr = "127.0.0.1:8300"
} else {
addr = addrRaw.(string)
}
// If they've passed 0.0.0.0, we only want to bind on IPv4
// rather than golang's dual stack default
if strings.HasPrefix(addr, "0.0.0.0:") {
bindProto = "tcp4"
}
ln, err := net.Listen(bindProto, addr)
if err != nil {
return nil, nil, nil, err
}
ln = server.TCPKeepAliveListener{ln.(*net.TCPListener)}
props := map[string]string{"addr": addr}
return server.ListenerWrapTLS(ln, props, config, ui)
}
// rmListener is an implementation of net.Listener that forwards most
// calls to the listener but also removes a file as part of the close. We
// use this to cleanup the unix domain socket on close.
type rmListener struct {
net.Listener
Path string
}
func (l *rmListener) Close() error {
// Close the listener itself
if err := l.Listener.Close(); err != nil {
return err
}
// Remove the file
return os.Remove(l.Path)
}

28
command/agent/cache/proxy.go vendored Normal file
View File

@ -0,0 +1,28 @@
package cache
import (
"context"
"net/http"
"github.com/hashicorp/vault/api"
)
// SendRequest is the input for Proxier.Send.
type SendRequest struct {
Token string
Request *http.Request
RequestBody []byte
}
// SendResponse is the output from Proxier.Send.
type SendResponse struct {
Response *api.Response
ResponseBody []byte
}
// Proxier is the interface implemented by different components that are
// responsible for performing specific tasks, such as caching and proxying. All
// these tasks combined together would serve the request received by the agent.
type Proxier interface {
Send(ctx context.Context, req *SendRequest) (*SendResponse, error)
}

36
command/agent/cache/testing.go vendored Normal file
View File

@ -0,0 +1,36 @@
package cache
import (
"context"
"fmt"
)
// mockProxier is a mock implementation of the Proxier interface, used for testing purposes.
// The mock will return the provided responses every time it reaches its Send method, up to
// the last provided response. This lets tests control what the next/underlying Proxier layer
// might expect to return.
type mockProxier struct {
proxiedResponses []*SendResponse
responseIndex int
}
func newMockProxier(responses []*SendResponse) *mockProxier {
return &mockProxier{
proxiedResponses: responses,
}
}
func (p *mockProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
if p.responseIndex >= len(p.proxiedResponses) {
return nil, fmt.Errorf("index out of bounds: responseIndex = %d, responses = %d", p.responseIndex, len(p.proxiedResponses))
}
resp := p.proxiedResponses[p.responseIndex]
p.responseIndex++
return resp, nil
}
func (p *mockProxier) ResponseIndex() int {
return p.responseIndex
}

View File

@ -0,0 +1,280 @@
package agent
import (
"context"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"testing"
"time"
hclog "github.com/hashicorp/go-hclog"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
"github.com/hashicorp/vault/command/agent/auth"
agentapprole "github.com/hashicorp/vault/command/agent/auth/approle"
"github.com/hashicorp/vault/command/agent/cache"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
"github.com/hashicorp/vault/helper/logging"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
func TestCache_UsingAutoAuthToken(t *testing.T) {
var err error
logger := logging.NewVaultLogger(log.Trace)
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: log.NewNullLogger(),
CredentialBackends: map[string]logical.Factory{
"approle": credAppRole.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
cores := cluster.Cores
vault.TestWaitActive(t, cores[0].Core)
client := cores[0].Client
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
os.Setenv(api.EnvVaultAddress, client.Address())
defer os.Setenv(api.EnvVaultCACert, os.Getenv(api.EnvVaultCACert))
os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))
err = client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{
Type: "approle",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/approle/role/test1", map[string]interface{}{
"bind_secret_id": "true",
"token_ttl": "3s",
"token_max_ttl": "10s",
})
if err != nil {
t.Fatal(err)
}
resp, err := client.Logical().Write("auth/approle/role/test1/secret-id", nil)
if err != nil {
t.Fatal(err)
}
secretID1 := resp.Data["secret_id"].(string)
resp, err = client.Logical().Read("auth/approle/role/test1/role-id")
if err != nil {
t.Fatal(err)
}
roleID1 := resp.Data["role_id"].(string)
rolef, err := ioutil.TempFile("", "auth.role-id.test.")
if err != nil {
t.Fatal(err)
}
role := rolef.Name()
rolef.Close() // WriteFile doesn't need it open
defer os.Remove(role)
t.Logf("input role_id_file_path: %s", role)
secretf, err := ioutil.TempFile("", "auth.secret-id.test.")
if err != nil {
t.Fatal(err)
}
secret := secretf.Name()
secretf.Close()
defer os.Remove(secret)
t.Logf("input secret_id_file_path: %s", secret)
// We close these right away because we're just basically testing
// permissions and finding a usable file name
ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
if err != nil {
t.Fatal(err)
}
out := ouf.Name()
ouf.Close()
os.Remove(out)
t.Logf("output: %s", out)
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
conf := map[string]interface{}{
"role_id_file_path": role,
"secret_id_file_path": secret,
"remove_secret_id_file_after_reading": true,
}
am, err := agentapprole.NewApproleAuthMethod(&auth.AuthConfig{
Logger: logger.Named("auth.approle"),
MountPath: "auth/approle",
Config: conf,
})
if err != nil {
t.Fatal(err)
}
ahConfig := &auth.AuthHandlerConfig{
Logger: logger.Named("auth.handler"),
Client: client,
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
defer func() {
<-ah.DoneCh
}()
config := &sink.SinkConfig{
Logger: logger.Named("sink.file"),
Config: map[string]interface{}{
"path": out,
},
}
fs, err := file.NewFileSink(config)
if err != nil {
t.Fatal(err)
}
config.Sink = fs
ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
defer func() {
<-ss.DoneCh
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
// Check that no sink file exists
_, err = os.Lstat(out)
if err == nil {
t.Fatal("expected err")
}
if !os.IsNotExist(err) {
t.Fatal("expected notexist err")
}
if err := ioutil.WriteFile(role, []byte(roleID1), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test role 1", "path", role)
}
if err := ioutil.WriteFile(secret, []byte(secretID1), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test secret 1", "path", secret)
}
getToken := func() string {
timeout := time.Now().Add(10 * time.Second)
for {
if time.Now().After(timeout) {
t.Fatal("did not find a written token after timeout")
}
val, err := ioutil.ReadFile(out)
if err == nil {
os.Remove(out)
if len(val) == 0 {
t.Fatal("written token was empty")
}
_, err = os.Stat(secret)
if err == nil {
t.Fatal("secret file exists but was supposed to be removed")
}
client.SetToken(string(val))
_, err := client.Auth().Token().LookupSelf()
if err != nil {
t.Fatal(err)
}
return string(val)
}
time.Sleep(250 * time.Millisecond)
}
}
t.Logf("auto-auth token: %q", getToken())
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
cacheLogger := logging.NewVaultLogger(hclog.Trace).Named("cache")
// Create the API proxier
apiProxy := cache.NewAPIProxy(&cache.APIProxyConfig{
Logger: cacheLogger.Named("apiproxy"),
})
// Create the lease cache proxier and set its underlying proxier to
// the API proxier.
leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
BaseContext: ctx,
Proxier: apiProxy,
Logger: cacheLogger.Named("leasecache"),
})
if err != nil {
t.Fatal(err)
}
// Create a muxer and add paths relevant for the lease cache layer
mux := http.NewServeMux()
mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, true, client))
server := &http.Server{
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
IdleTimeout: 5 * time.Minute,
ErrorLog: cacheLogger.StandardLogger(nil),
}
go server.Serve(listener)
testClient, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
if err := testClient.SetAddress("http://" + listener.Addr().String()); err != nil {
t.Fatal(err)
}
// Wait for listeners to come up
time.Sleep(2 * time.Second)
resp, err = testClient.Logical().Read("auth/token/lookup-self")
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatalf("failed to use the auto-auth token to perform lookup-self")
}
}

View File

@ -22,6 +22,17 @@ type Config struct {
AutoAuth *AutoAuth `hcl:"auto_auth"`
ExitAfterAuth bool `hcl:"exit_after_auth"`
PidFile string `hcl:"pid_file"`
Cache *Cache `hcl:"cache"`
}
type Cache struct {
UseAutoAuthToken bool `hcl:"use_auto_auth_token"`
Listeners []*Listener `hcl:"listeners"`
}
type Listener struct {
Type string
Config map[string]interface{}
}
type AutoAuth struct {
@ -91,9 +102,102 @@ func LoadConfig(path string, logger log.Logger) (*Config, error) {
return nil, errwrap.Wrapf("error parsing 'auto_auth': {{err}}", err)
}
err = parseCache(&result, list)
if err != nil {
return nil, errwrap.Wrapf("error parsing 'cache':{{err}}", err)
}
return &result, nil
}
func parseCache(result *Config, list *ast.ObjectList) error {
name := "cache"
cacheList := list.Filter(name)
if len(cacheList.Items) == 0 {
return nil
}
if len(cacheList.Items) > 1 {
return fmt.Errorf("one and only one %q block is required", name)
}
item := cacheList.Items[0]
var c Cache
err := hcl.DecodeObject(&c, item.Val)
if err != nil {
return err
}
result.Cache = &c
subs, ok := item.Val.(*ast.ObjectType)
if !ok {
return fmt.Errorf("could not parse %q as an object", name)
}
subList := subs.List
err = parseListeners(result, subList)
if err != nil {
return errwrap.Wrapf("error parsing 'listener' stanzas: {{err}}", err)
}
return nil
}
func parseListeners(result *Config, list *ast.ObjectList) error {
name := "listener"
listenerList := list.Filter(name)
if len(listenerList.Items) < 1 {
return fmt.Errorf("at least one %q block is required", name)
}
var listeners []*Listener
for _, item := range listenerList.Items {
var lnConfig map[string]interface{}
err := hcl.DecodeObject(&lnConfig, item.Val)
if err != nil {
return err
}
var lnType string
switch {
case lnConfig["type"] != nil:
lnType = lnConfig["type"].(string)
delete(lnConfig, "type")
case len(item.Keys) == 1:
lnType = strings.ToLower(item.Keys[0].Token.Value().(string))
default:
return errors.New("listener type must be specified")
}
switch lnType {
case "unix":
// Don't accept TLS connection information for unix domain socket
// listener. Maybe something to support in future.
unixLnConfig := map[string]interface{}{
"tls_disable": true,
}
unixLnConfig["address"] = lnConfig["address"]
lnConfig = unixLnConfig
case "tcp":
default:
return fmt.Errorf("invalid listener type %q", lnType)
}
listeners = append(listeners, &Listener{
Type: lnType,
Config: lnConfig,
})
}
result.Cache.Listeners = listeners
return nil
}
func parseAutoAuth(result *Config, list *ast.ObjectList) error {
name := "auto_auth"

View File

@ -10,6 +10,80 @@ import (
"github.com/hashicorp/vault/helper/logging"
)
func TestLoadConfigFile_AgentCache(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)
config, err := LoadConfig("./test-fixtures/config-cache.hcl", logger)
if err != nil {
t.Fatal(err)
}
expected := &Config{
AutoAuth: &AutoAuth{
Method: &Method{
Type: "aws",
WrapTTL: 300 * time.Second,
MountPath: "auth/aws",
Config: map[string]interface{}{
"role": "foobar",
},
},
Sinks: []*Sink{
&Sink{
Type: "file",
DHType: "curve25519",
DHPath: "/tmp/file-foo-dhpath",
AAD: "foobar",
Config: map[string]interface{}{
"path": "/tmp/file-foo",
},
},
},
},
Cache: &Cache{
UseAutoAuthToken: true,
Listeners: []*Listener{
&Listener{
Type: "unix",
Config: map[string]interface{}{
"address": "/path/to/socket",
"tls_disable": true,
},
},
&Listener{
Type: "tcp",
Config: map[string]interface{}{
"address": "127.0.0.1:8300",
"tls_disable": true,
},
},
&Listener{
Type: "tcp",
Config: map[string]interface{}{
"address": "127.0.0.1:8400",
"tls_key_file": "/path/to/cakey.pem",
"tls_cert_file": "/path/to/cacert.pem",
},
},
},
},
PidFile: "./pidfile",
}
if diff := deep.Equal(config, expected); diff != nil {
t.Fatal(diff)
}
config, err = LoadConfig("./test-fixtures/config-cache-embedded-type.hcl", logger)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(config, expected); diff != nil {
t.Fatal(diff)
}
}
func TestLoadConfigFile(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)

View File

@ -0,0 +1,44 @@
pid_file = "./pidfile"
auto_auth {
method {
type = "aws"
wrap_ttl = 300
config = {
role = "foobar"
}
}
sink {
type = "file"
config = {
path = "/tmp/file-foo"
}
aad = "foobar"
dh_type = "curve25519"
dh_path = "/tmp/file-foo-dhpath"
}
}
cache {
use_auto_auth_token = true
listener {
type = "unix"
address = "/path/to/socket"
tls_disable = true
}
listener {
type = "tcp"
address = "127.0.0.1:8300"
tls_disable = true
}
listener {
type = "tcp"
address = "127.0.0.1:8400"
tls_key_file = "/path/to/cakey.pem"
tls_cert_file = "/path/to/cacert.pem"
}
}

View File

@ -0,0 +1,41 @@
pid_file = "./pidfile"
auto_auth {
method {
type = "aws"
wrap_ttl = 300
config = {
role = "foobar"
}
}
sink {
type = "file"
config = {
path = "/tmp/file-foo"
}
aad = "foobar"
dh_type = "curve25519"
dh_path = "/tmp/file-foo-dhpath"
}
}
cache {
use_auto_auth_token = true
listener "unix" {
address = "/path/to/socket"
tls_disable = true
}
listener "tcp" {
address = "127.0.0.1:8300"
tls_disable = true
}
listener "tcp" {
address = "127.0.0.1:8400"
tls_key_file = "/path/to/cakey.pem"
tls_cert_file = "/path/to/cacert.pem"
}
}

View File

@ -5,6 +5,7 @@ import (
"io/ioutil"
"os"
"testing"
"time"
hclog "github.com/hashicorp/go-hclog"
vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
@ -30,6 +31,188 @@ func testAgentCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *AgentCo
}
}
func TestAgent_Cache_UnixListener(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
Logger: logger.Named("core"),
CredentialBackends: map[string]logical.Factory{
"jwt": vaultjwt.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
client := cluster.Cores[0].Client
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
os.Setenv(api.EnvVaultAddress, client.Address())
defer os.Setenv(api.EnvVaultCACert, os.Getenv(api.EnvVaultCACert))
os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))
// Setup Vault
err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{
Type: "jwt",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
"bound_issuer": "https://team-vault.auth0.com/",
"jwt_validation_pubkeys": agent.TestECDSAPubKey,
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
"bound_audiences": "https://vault.plugin.auth.jwt.test",
"user_claim": "https://vault/user",
"groups_claim": "https://vault/groups",
"policies": "test",
"period": "3s",
})
if err != nil {
t.Fatal(err)
}
inf, err := ioutil.TempFile("", "auth.jwt.test.")
if err != nil {
t.Fatal(err)
}
in := inf.Name()
inf.Close()
os.Remove(in)
t.Logf("input: %s", in)
sink1f, err := ioutil.TempFile("", "sink1.jwt.test.")
if err != nil {
t.Fatal(err)
}
sink1 := sink1f.Name()
sink1f.Close()
os.Remove(sink1)
t.Logf("sink1: %s", sink1)
sink2f, err := ioutil.TempFile("", "sink2.jwt.test.")
if err != nil {
t.Fatal(err)
}
sink2 := sink2f.Name()
sink2f.Close()
os.Remove(sink2)
t.Logf("sink2: %s", sink2)
conff, err := ioutil.TempFile("", "conf.jwt.test.")
if err != nil {
t.Fatal(err)
}
conf := conff.Name()
conff.Close()
os.Remove(conf)
t.Logf("config: %s", conf)
jwtToken, _ := agent.GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test jwt", "path", in)
}
socketff, err := ioutil.TempFile("", "cache.socket.")
if err != nil {
t.Fatal(err)
}
socketf := socketff.Name()
socketff.Close()
os.Remove(socketf)
t.Logf("socketf: %s", socketf)
config := `
auto_auth {
method {
type = "jwt"
config = {
role = "test"
path = "%s"
}
}
sink {
type = "file"
config = {
path = "%s"
}
}
sink "file" {
config = {
path = "%s"
}
}
}
cache {
use_auto_auth_token = true
listener "unix" {
address = "%s"
tls_disable = true
}
}
`
config = fmt.Sprintf(config, in, sink1, sink2, socketf)
if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test config", "path", conf)
}
_, cmd := testAgentCommand(t, logger)
cmd.client = client
// Kill the command 5 seconds after it starts
go func() {
select {
case <-cmd.ShutdownCh:
case <-time.After(5 * time.Second):
cmd.ShutdownCh <- struct{}{}
}
}()
originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddress)
// Create a client that talks to the agent
os.Setenv(api.EnvVaultAgentAddress, socketf)
testClient, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
os.Setenv(api.EnvVaultAgentAddress, originalVaultAgentAddress)
// Start the agent
go cmd.Run([]string{"-config", conf})
// Give some time for the auto-auth to complete
time.Sleep(1 * time.Second)
// Invoke lookup self through the agent
secret, err := testClient.Auth().Token().LookupSelf()
if err != nil {
t.Fatal(err)
}
if secret == nil || secret.Data == nil || secret.Data["id"].(string) == "" {
t.Fatalf("failed to perform lookup self through agent")
}
}
func TestExitAfterAuth(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{

View File

@ -39,6 +39,7 @@ type BaseCommand struct {
flagsOnce sync.Once
flagAddress string
flagAgentAddress string
flagCACert string
flagCAPath string
flagClientCert string
@ -78,6 +79,9 @@ func (c *BaseCommand) Client() (*api.Client, error) {
if c.flagAddress != "" {
config.Address = c.flagAddress
}
if c.flagAgentAddress != "" {
config.Address = c.flagAgentAddress
}
if c.flagOutputCurlString {
config.OutputCurlString = c.flagOutputCurlString
@ -220,6 +224,15 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
}
f.StringVar(addrStringVar)
agentAddrStringVar := &StringVar{
Name: "agent-address",
Target: &c.flagAgentAddress,
EnvVar: "VAULT_AGENT_ADDR",
Completion: complete.PredictAnything,
Usage: "Address of the Agent.",
}
f.StringVar(agentAddrStringVar)
f.StringVar(&StringVar{
Name: "ca-cert",
Target: &c.flagCACert,

View File

@ -72,7 +72,7 @@ func listenerWrapProxy(ln net.Listener, config map[string]interface{}) (net.List
return newLn, nil
}
func listenerWrapTLS(
func ListenerWrapTLS(
ln net.Listener,
props map[string]string,
config map[string]interface{},

View File

@ -35,7 +35,7 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) (
return nil, nil, nil, err
}
ln = tcpKeepAliveListener{ln.(*net.TCPListener)}
ln = TCPKeepAliveListener{ln.(*net.TCPListener)}
ln, err = listenerWrapProxy(ln, config)
if err != nil {
@ -94,20 +94,20 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) (
config["x_forwarded_for_reject_not_authorized"] = true
}
return listenerWrapTLS(ln, props, config, ui)
return ListenerWrapTLS(ln, props, config, ui)
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// TCPKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
//
// This is copied directly from the Go source code.
type tcpKeepAliveListener struct {
type TCPKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
func (ln TCPKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return