1
0

Push a lot of logic into Router to make a bunch of it nicer and enable a

lot of cleanup. Plumb config and calls to framework.Backend.Setup() into
logical_system and elsewhere, including tests.
This commit is contained in:
Jeff Mitchell 2015-09-04 16:58:12 -04:00
parent 76c18762aa
commit 3e713c61ac
20 changed files with 268 additions and 231 deletions

2
.gitignore vendored
View File

@ -49,6 +49,8 @@ Vagrantfile
dist/*
tags
# Editor backups
*~
*.sw[a-z]

View File

@ -8,7 +8,7 @@ import (
"github.com/hashicorp/vault/vault"
)
// RemountCommand is a Command that remounts a mounted secret backend
// MountTuneCommand is a Command that remounts a mounted secret backend
// to a new endpoint.
type MountTuneCommand struct {
Meta

View File

@ -66,7 +66,7 @@ func (c *Core) enableCredential(entry *MountEntry) error {
view := NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
// Create the new backend
backend, err := c.newCredentialBackend(entry.Type, view, nil)
backend, err := c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
if err != nil {
return err
}
@ -81,7 +81,7 @@ func (c *Core) enableCredential(entry *MountEntry) error {
// Mount the backend
path := credentialRoutePrefix + entry.Path
if err := c.router.Mount(backend, path, entry.UUID, view); err != nil {
if err := c.router.Mount(backend, path, entry, view); err != nil {
return err
}
c.logger.Printf("[INFO] core: enabled credential backend '%s' type: %s",
@ -242,7 +242,7 @@ func (c *Core) setupCredentials() error {
view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
// Initialize the backend
backend, err = c.newCredentialBackend(entry.Type, view, nil)
backend, err = c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
if err != nil {
c.logger.Printf(
"[ERR] core: failed to create credential entry %#v: %v",
@ -252,7 +252,7 @@ func (c *Core) setupCredentials() error {
// Mount the backend
path := credentialRoutePrefix + entry.Path
err = c.router.Mount(backend, path, entry.UUID, view)
err = c.router.Mount(backend, path, entry, view)
if err != nil {
c.logger.Printf("[ERR] core: failed to mount auth entry %#v: %v", entry, err)
return loadAuthFailed
@ -281,7 +281,7 @@ func (c *Core) teardownCredentials() error {
// newCredentialBackend is used to create and configure a new credential backend by name
func (c *Core) newCredentialBackend(
t string, view logical.Storage, conf map[string]string) (logical.Backend, error) {
t string, sysView logical.SystemView, view logical.Storage, conf map[string]string) (logical.Backend, error) {
f, ok := c.credentialBackends[t]
if !ok {
return nil, fmt.Errorf("unknown backend type: %s", t)
@ -291,12 +291,14 @@ func (c *Core) newCredentialBackend(
View: view,
Logger: c.logger,
Config: conf,
System: sysView,
}
b, err := f(config)
if err != nil {
return nil, err
}
return b, nil
}

View File

@ -220,8 +220,8 @@ type Core struct {
// out into the configured audit backends
auditBroker *AuditBroker
// systemView is the barrier view for the system backend
systemView *BarrierView
// systemBarrierView is the barrier view for the system backend
systemBarrierView *BarrierView
// expiration manager is used for managing LeaseIDs,
// renewal, expiration and revocation
@ -351,8 +351,8 @@ func NewCore(conf *CoreConfig) (*Core, error) {
logicalBackends[k] = f
}
logicalBackends["generic"] = PassthroughBackendFactory
logicalBackends["system"] = func(*logical.BackendConfig) (logical.Backend, error) {
return NewSystemBackend(c), nil
logicalBackends["system"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return NewSystemBackend(c, config), nil
}
c.logicalBackends = logicalBackends
@ -360,8 +360,8 @@ func NewCore(conf *CoreConfig) (*Core, error) {
for k, f := range conf.CredentialBackends {
credentialBackends[k] = f
}
credentialBackends["token"] = func(*logical.BackendConfig) (logical.Backend, error) {
return NewTokenStore(c)
credentialBackends["token"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return NewTokenStore(c, config)
}
c.credentialBackends = credentialBackends
@ -478,9 +478,9 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
// We exclude renewal of a lease, since it does not need to be re-registered
if resp != nil && resp.Secret != nil && !strings.HasPrefix(req.Path, "sys/renew/") {
// Get the SystemView for the mount
sysView, err := c.sysViewByPath(req.Path)
if err != nil {
c.logger.Println(err)
sysView := c.router.MatchingSystemView(req.Path)
if sysView == nil {
c.logger.Println("[ERR] core: unable to retrieve system view from router")
return nil, auth, ErrInternalError
}

View File

@ -0,0 +1,54 @@
package vault
import (
"fmt"
"strings"
"time"
)
type dynamicSystemView struct {
core *Core
path string
}
func (d dynamicSystemView) DefaultLeaseTTL() (time.Duration, error) {
def, _, err := d.fetchTTLs()
if err != nil {
return 0, err
}
return def, nil
}
func (d dynamicSystemView) MaxLeaseTTL() (time.Duration, error) {
_, max, err := d.fetchTTLs()
if err != nil {
return 0, err
}
return max, nil
}
// TTLsByPath returns the default and max TTLs corresponding to a particular
// mount point, or the system default
func (d dynamicSystemView) fetchTTLs() (def, max time.Duration, retErr error) {
// Ensure we end the path in a slash
if !strings.HasSuffix(d.path, "/") {
d.path += "/"
}
me := d.core.router.MatchingMountEntry(d.path)
if me == nil {
return 0, 0, fmt.Errorf("[ERR] core: failed to get mount entry for %s", d.path)
}
def = d.core.defaultLeaseTTL
max = d.core.maxLeaseTTL
if me.Config.DefaultLeaseTTL != nil && *me.Config.DefaultLeaseTTL != 0 {
def = *me.Config.DefaultLeaseTTL
}
if me.Config.MaxLeaseTTL != nil && *me.Config.MaxLeaseTTL != 0 {
max = *me.Config.MaxLeaseTTL
}
return
}

View File

@ -78,7 +78,7 @@ func NewExpirationManager(router *Router, view *BarrierView, ts *TokenStore, log
// initialize the expiration manager
func (c *Core) setupExpiration() error {
// Create a sub-view
view := c.systemView.SubView(expirationSubPath)
view := c.systemBarrierView.SubView(expirationSubPath)
// Create the manager
mgr := NewExpirationManager(c.router, view, c.tokenStore, c.logger)

View File

@ -22,7 +22,7 @@ func TestExpiration_Restore(t *testing.T) {
noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view)
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
paths := []string{
"prod/aws/foo",
@ -175,7 +175,7 @@ func TestExpiration_Revoke(t *testing.T) {
noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view)
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
req := &logical.Request{
Operation: logical.ReadOperation,
@ -213,7 +213,7 @@ func TestExpiration_RevokeOnExpire(t *testing.T) {
noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view)
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
req := &logical.Request{
Operation: logical.ReadOperation,
@ -262,7 +262,7 @@ func TestExpiration_RevokePrefix(t *testing.T) {
noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view)
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
paths := []string{
"prod/aws/foo",
@ -322,7 +322,7 @@ func TestExpiration_RevokeByToken(t *testing.T) {
noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view)
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
paths := []string{
"prod/aws/foo",
@ -441,7 +441,7 @@ func TestExpiration_Renew(t *testing.T) {
noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view)
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
req := &logical.Request{
Operation: logical.ReadOperation,
@ -503,7 +503,7 @@ func TestExpiration_Renew_NotRenewable(t *testing.T) {
noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view)
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
req := &logical.Request{
Operation: logical.ReadOperation,
@ -545,7 +545,7 @@ func TestExpiration_Renew_RevokeOnExpire(t *testing.T) {
noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view)
exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
req := &logical.Request{
Operation: logical.ReadOperation,
@ -613,7 +613,7 @@ func TestExpiration_revokeEntry(t *testing.T) {
noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "", uuid.GenerateUUID(), view)
exp.router.Mount(noop, "", &MountEntry{UUID: uuid.GenerateUUID()}, view)
le := &leaseEntry{
LeaseID: "foo/bar/1234",
@ -702,7 +702,7 @@ func TestExpiration_renewEntry(t *testing.T) {
}
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "", uuid.GenerateUUID(), view)
exp.router.Mount(noop, "", &MountEntry{UUID: uuid.GenerateUUID()}, view)
le := &leaseEntry{
LeaseID: "foo/bar/1234",
@ -764,7 +764,7 @@ func TestExpiration_renewAuthEntry(t *testing.T) {
}
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "auth/foo/")
exp.router.Mount(noop, "auth/foo/", uuid.GenerateUUID(), view)
exp.router.Mount(noop, "auth/foo/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
le := &leaseEntry{
LeaseID: "auth/foo/1234",

View File

@ -11,7 +11,7 @@ import (
)
// logical.Factory
func PassthroughBackendFactory(*logical.BackendConfig) (logical.Backend, error) {
func PassthroughBackendFactory(conf *logical.BackendConfig) (logical.Backend, error) {
var b PassthroughBackend
b.Backend = &framework.Backend{
Help: strings.TrimSpace(passthroughHelp),
@ -53,6 +53,11 @@ func PassthroughBackendFactory(*logical.BackendConfig) (logical.Backend, error)
},
}
if conf == nil {
return nil, fmt.Errorf("Configuation passed into backend is nil")
}
b.Backend.Setup(conf)
return b, nil
}

View File

@ -176,6 +176,12 @@ func TestPassthroughBackend_List(t *testing.T) {
}
func testPassthroughBackend() logical.Backend {
b, _ := PassthroughBackendFactory(nil)
b, _ := PassthroughBackendFactory(&logical.BackendConfig{
Logger: nil,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 30,
},
})
return b
}

View File

@ -20,10 +20,11 @@ var (
}
)
func NewSystemBackend(core *Core) logical.Backend {
func NewSystemBackend(core *Core, config *logical.BackendConfig) logical.Backend {
b := &SystemBackend{
Core: core,
}
b.Backend = &framework.Backend{
Help: strings.TrimSpace(sysHelpRoot),
@ -346,6 +347,9 @@ func NewSystemBackend(core *Core) logical.Backend {
},
},
}
b.Backend.Setup(config)
return b.Backend
}
@ -486,9 +490,26 @@ func (b *SystemBackend) handleMountConfig(
logical.ErrInvalidRequest
}
def, max, err := b.Core.TTLsByPath(path)
if !strings.HasSuffix(path, "/") {
path += "/"
}
sysView := b.Core.router.MatchingSystemView(path)
if sysView == nil {
err := fmt.Errorf("[ERR] sys: cannot fetch sysview for path %s", path)
b.Backend.Logger().Print(err)
return handleError(err)
}
def, err := sysView.DefaultLeaseTTL()
if err != nil {
b.Backend.Logger().Printf("[ERR] sys: fetching config of path '%s' failed: %v", path, err)
b.Backend.Logger().Printf("[ERR] sys: fetching config default TTL of path '%s' failed: %v", path, err)
return handleError(err)
}
max, err := sysView.MaxLeaseTTL()
if err != nil {
b.Backend.Logger().Printf("[ERR] sys: fetching config max TTL of path '%s' failed: %v", path, err)
return handleError(err)
}
@ -516,6 +537,10 @@ func (b *SystemBackend) handleMountTune(
logical.ErrInvalidRequest
}
if !strings.HasSuffix(path, "/") {
path += "/"
}
var config MountConfig
configMap := data.Get("config").(map[string]interface{})
if configMap == nil || len(configMap) == 0 {

View File

@ -760,10 +760,24 @@ func TestSystemBackend_rotate(t *testing.T) {
func testSystemBackend(t *testing.T) logical.Backend {
c, _, _ := TestCoreUnsealed(t)
return NewSystemBackend(c)
bc := &logical.BackendConfig{
Logger: c.logger,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 30,
},
}
return NewSystemBackend(c, bc)
}
func testCoreSystemBackend(t *testing.T) (*Core, logical.Backend, string) {
c, _, root := TestCoreUnsealed(t)
return c, NewSystemBackend(c), root
bc := &logical.BackendConfig{
Logger: c.logger,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 30,
},
}
return c, NewSystemBackend(c, bc), root
}

View File

@ -40,27 +40,6 @@ var (
}
)
type dynamicSystemView struct {
core *Core
path string
}
func (d dynamicSystemView) DefaultLeaseTTL() (time.Duration, error) {
def, _, err := d.core.TTLsByPath(d.path)
if err != nil {
return 0, err
}
return def, nil
}
func (d dynamicSystemView) MaxLeaseTTL() (time.Duration, error) {
_, max, err := d.core.TTLsByPath(d.path)
if err != nil {
return 0, err
}
return max, nil
}
// MountTable is used to represent the internal mount table
type MountTable struct {
// This lock should be held whenever modifying the Entries field.
@ -185,12 +164,7 @@ func (c *Core) mount(me *MountEntry) error {
me.UUID = uuid.GenerateUUID()
view := NewBarrierView(c.barrier, backendBarrierPrefix+me.UUID+"/")
// Create the new backend
sysView, err := c.mountEntrySysView(me)
if err != nil {
return err
}
backend, err := c.newLogicalBackend(me.Type, sysView, view, nil)
backend, err := c.newLogicalBackend(me.Type, c.mountEntrySysView(me), view, nil)
if err != nil {
return err
}
@ -204,7 +178,7 @@ func (c *Core) mount(me *MountEntry) error {
c.mounts = newTable
// Mount the backend
if err := c.router.Mount(backend, me.Path, me.UUID, view); err != nil {
if err := c.router.Mount(backend, me.Path, me, view); err != nil {
return err
}
c.logger.Printf("[INFO] core: mounted '%s' type: %s", me.Path, me.Type)
@ -394,51 +368,44 @@ func (c *Core) tuneMount(path string, config MountConfig) error {
// Prevent protected paths from being changed
for _, p := range protectedMounts {
if strings.HasPrefix(path, p) {
return fmt.Errorf("cannot tune '%s'", path)
return fmt.Errorf("[ERR] core: cannot tune '%s'", path)
}
}
// Verify exact match of the route
match := c.router.MatchingMount(path)
if match == "" || path != match {
return fmt.Errorf("no matching mount at '%s'", path)
me := c.router.MatchingMountEntry(path)
if me == nil {
return fmt.Errorf("[ERR] core: no matching mount at '%s'", path)
}
// Find and modify mount
for _, ent := range c.mounts.Entries {
if ent.Path == path {
if config.MaxLeaseTTL != nil {
if *ent.Config.DefaultLeaseTTL != 0 {
if *config.MaxLeaseTTL < *ent.Config.DefaultLeaseTTL {
return fmt.Errorf("Given backend max lease TTL of %d less than backend default lease TTL of %d",
*config.MaxLeaseTTL, *ent.Config.DefaultLeaseTTL)
}
}
if *config.MaxLeaseTTL == 0 {
*ent.Config.MaxLeaseTTL = 0
} else {
ent.Config.MaxLeaseTTL = config.MaxLeaseTTL
}
if config.MaxLeaseTTL != nil {
if *me.Config.DefaultLeaseTTL != 0 {
if *config.MaxLeaseTTL < *me.Config.DefaultLeaseTTL {
return fmt.Errorf("Given backend max lease TTL of %d less than backend default lease TTL of %d",
*config.MaxLeaseTTL, *me.Config.DefaultLeaseTTL)
}
if config.DefaultLeaseTTL != nil {
if *ent.Config.MaxLeaseTTL == 0 {
if *config.DefaultLeaseTTL > c.maxLeaseTTL {
return fmt.Errorf("Given default lease TTL of %d greater than system default lease TTL of %d",
*config.DefaultLeaseTTL, c.maxLeaseTTL)
}
} else {
if *ent.Config.MaxLeaseTTL != 0 && *ent.Config.MaxLeaseTTL < *config.DefaultLeaseTTL {
return fmt.Errorf("Given default lease TTL of %d greater than backend max lease TTL of %d",
*config.DefaultLeaseTTL, *ent.Config.MaxLeaseTTL)
}
}
if *config.DefaultLeaseTTL == 0 {
*ent.Config.DefaultLeaseTTL = 0
} else {
ent.Config.DefaultLeaseTTL = config.DefaultLeaseTTL
}
}
if *config.MaxLeaseTTL == 0 {
*me.Config.MaxLeaseTTL = 0
} else {
me.Config.MaxLeaseTTL = config.MaxLeaseTTL
}
}
if config.DefaultLeaseTTL != nil {
if *me.Config.MaxLeaseTTL == 0 {
if *config.DefaultLeaseTTL > c.maxLeaseTTL {
return fmt.Errorf("Given default lease TTL of %d greater than system default lease TTL of %d",
*config.DefaultLeaseTTL, c.maxLeaseTTL)
}
break
} else {
if *me.Config.MaxLeaseTTL != 0 && *me.Config.MaxLeaseTTL < *config.DefaultLeaseTTL {
return fmt.Errorf("Given default lease TTL of %d greater than backend max lease TTL of %d",
*config.DefaultLeaseTTL, *me.Config.MaxLeaseTTL)
}
}
if *config.DefaultLeaseTTL == 0 {
*me.Config.DefaultLeaseTTL = 0
} else {
me.Config.DefaultLeaseTTL = config.DefaultLeaseTTL
}
}
@ -508,6 +475,7 @@ func (c *Core) persistMounts(table *MountTable) error {
func (c *Core) setupMounts() error {
var backend logical.Backend
var view *BarrierView
var err error
for _, entry := range c.mounts.Entries {
// Initialize the backend, special casing for system
barrierPath := backendBarrierPrefix + entry.UUID + "/"
@ -520,11 +488,7 @@ func (c *Core) setupMounts() error {
// Initialize the backend
// Create the new backend
sysView, err := c.mountEntrySysView(entry)
if err != nil {
return err
}
backend, err = c.newLogicalBackend(entry.Type, sysView, view, nil)
backend, err = c.newLogicalBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
if err != nil {
c.logger.Printf(
"[ERR] core: failed to create mount entry %#v: %v",
@ -533,11 +497,11 @@ func (c *Core) setupMounts() error {
}
if entry.Type == "system" {
c.systemView = view
c.systemBarrierView = view
}
// Mount the backend
err = c.router.Mount(backend, entry.Path, entry.UUID, view)
err = c.router.Mount(backend, entry.Path, entry, view)
if err != nil {
c.logger.Printf("[ERR] core: failed to mount entry %#v: %v", entry, err)
return errLoadMountsFailed
@ -556,7 +520,7 @@ func (c *Core) setupMounts() error {
func (c *Core) unloadMounts() error {
c.mounts = nil
c.router = NewRouter()
c.systemView = nil
c.systemBarrierView = nil
return nil
}
@ -582,82 +546,13 @@ func (c *Core) newLogicalBackend(t string, sysView logical.SystemView, view logi
}
// mountEntrySysView creates a logical.SystemView from global and
// mount-specific entries
func (c *Core) mountEntrySysView(me *MountEntry) (logical.SystemView, error) {
if me == nil {
return nil, fmt.Errorf("[ERR] core: nil MountEntry when generating SystemView")
}
sysView := dynamicSystemView{
// mount-specific entries; because this should be called when setting
// up a mountEntry, it doesn't check to ensure that me is not nil
func (c *Core) mountEntrySysView(me *MountEntry) logical.SystemView {
return dynamicSystemView{
core: c,
path: me.Path,
}
return sysView, nil
}
// sysViewByPath is a simple helper for MountEntrySysView
func (c *Core) sysViewByPath(path string) (logical.SystemView, error) {
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
}
me, err := c.mountEntryByPath(path)
if err != nil {
return nil, err
}
return c.mountEntrySysView(me)
}
// mountEntryByPath searches across all tables to find the MountEntry
func (c *Core) mountEntryByPath(path string) (*MountEntry, error) {
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
}
pathSep := strings.IndexRune(path, '/')
if pathSep == -1 {
return nil, fmt.Errorf("[ERR] core: failed to find separator for path %s", path)
}
me := c.mounts.Find(path[0 : pathSep+1])
if me == nil {
me = c.auth.Find(path[0 : pathSep+1])
}
if me == nil {
me = c.audit.Find(path[0 : pathSep+1])
}
if me == nil {
return nil, fmt.Errorf("[ERR] core: failed to find mount entry for path %s", path)
}
return me, nil
}
// TTLsByPath returns the default and max TTLs corresponding to a particular
// mount point, or the system default
func (c *Core) TTLsByPath(path string) (def, max time.Duration, retErr error) {
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
}
me, err := c.mountEntryByPath(path)
if err != nil {
return 0, 0, err
}
def = c.defaultLeaseTTL
max = c.maxLeaseTTL
if me.Config.DefaultLeaseTTL != nil && *me.Config.DefaultLeaseTTL != 0 {
def = *me.Config.DefaultLeaseTTL
}
if me.Config.MaxLeaseTTL != nil && *me.Config.MaxLeaseTTL != 0 {
max = *me.Config.MaxLeaseTTL
}
return
}
// defaultMountTable creates a default mount table

View File

@ -192,7 +192,7 @@ func TestCore_Unmount_Cleanup(t *testing.T) {
func TestCore_Remount(t *testing.T) {
c, key, _ := TestCoreUnsealed(t)
err := c.remount("secret", "foo", MountConfig{})
err := c.remount("secret", "foo")
if err != nil {
t.Fatalf("err: %v", err)
}
@ -280,7 +280,7 @@ func TestCore_Remount_Cleanup(t *testing.T) {
}
// Remount, this should cleanup
if err := c.remount("test/", "new/", MountConfig{}); err != nil {
if err := c.remount("test/", "new/"); err != nil {
t.Fatalf("err: %v", err)
}
@ -309,7 +309,7 @@ func TestCore_Remount_Cleanup(t *testing.T) {
func TestCore_Remount_Protected(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
err := c.remount("sys", "foo", MountConfig{})
err := c.remount("sys", "foo")
if err.Error() != "cannot remount 'sys/'" {
t.Fatalf("err: %v", err)
}

View File

@ -46,7 +46,7 @@ func NewPolicyStore(view *BarrierView) *PolicyStore {
// when the vault is being unsealed.
func (c *Core) setupPolicyStore() error {
// Create a sub-view
view := c.systemView.SubView(policySubPath)
view := c.systemBarrierView.SubView(policySubPath)
// Create the policy store
c.policy = NewPolicyStore(view)

View File

@ -21,7 +21,7 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
Path: "foo",
},
}
if err := router.Mount(backend, "foo", uuid.GenerateUUID(), nil); err != nil {
if err := router.Mount(backend, "foo", &MountEntry{UUID: uuid.GenerateUUID()}, nil); err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -26,24 +26,24 @@ func NewRouter() *Router {
return r
}
// mountEntry is used to represent a mount point
type mountEntry struct {
// routeEntry is used to represent a mount point in the router
type routeEntry struct {
tainted bool
salt string
backend logical.Backend
mountEntry *MountEntry
view *BarrierView
rootPaths *radix.Tree
loginPaths *radix.Tree
}
// SaltID is used to apply a salt and hash to an ID to make sure its not reversable
func (me *mountEntry) SaltID(id string) string {
return salt.SaltID(me.salt, id, salt.SHA1Hash)
func (re *routeEntry) SaltID(id string) string {
return salt.SaltID(re.mountEntry.UUID, id, salt.SHA1Hash)
}
// Mount is used to expose a logical backend at a given prefix, using a unique salt,
// and the barrier view for that path.
func (r *Router) Mount(backend logical.Backend, prefix, salt string, view *BarrierView) error {
func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *MountEntry, view *BarrierView) error {
r.l.Lock()
defer r.l.Unlock()
@ -59,14 +59,15 @@ func (r *Router) Mount(backend logical.Backend, prefix, salt string, view *Barri
}
// Create a mount entry
me := &mountEntry{
re := &routeEntry{
tainted: false,
backend: backend,
mountEntry: mountEntry,
view: view,
rootPaths: pathsToRadix(paths.Root),
loginPaths: pathsToRadix(paths.Unauthenticated),
}
r.root.Insert(prefix, me)
r.root.Insert(prefix, re)
return nil
}
@ -91,12 +92,8 @@ func (r *Router) Remount(src, dst string) error {
// Update the mount point
r.root.Delete(src)
mountEntry, ok := raw.(*mountEntry)
if !ok {
return fmt.Errorf("Unable to retrieve mount entry at '%s'", src)
}
sysView := mountEntry.backend.System()
dynSysView, ok := sysView.(dynamicSystemView)
routeEntry := raw.(*routeEntry)
dynSysView, ok := routeEntry.backend.System().(dynamicSystemView)
if ok {
dynSysView.path = dst
}
@ -111,7 +108,7 @@ func (r *Router) Taint(path string) error {
defer r.l.Unlock()
_, raw, ok := r.root.LongestPrefix(path)
if ok {
raw.(*mountEntry).tainted = true
raw.(*routeEntry).tainted = true
}
return nil
}
@ -122,7 +119,7 @@ func (r *Router) Untaint(path string) error {
defer r.l.Unlock()
_, raw, ok := r.root.LongestPrefix(path)
if ok {
raw.(*mountEntry).tainted = false
raw.(*routeEntry).tainted = false
}
return nil
}
@ -146,7 +143,29 @@ func (r *Router) MatchingView(path string) *BarrierView {
if !ok {
return nil
}
return raw.(*mountEntry).view
return raw.(*routeEntry).view
}
// MatchingMountEntry returns the MountEntry used for a path
func (r *Router) MatchingMountEntry(path string) *MountEntry {
r.l.RLock()
_, raw, ok := r.root.LongestPrefix(path)
r.l.RUnlock()
if !ok {
return nil
}
return raw.(*routeEntry).mountEntry
}
// MatchingSystemView returns the SystemView used for a path
func (r *Router) MatchingSystemView(path string) logical.SystemView {
r.l.RLock()
_, raw, ok := r.root.LongestPrefix(path)
r.l.RUnlock()
if !ok {
return nil
}
return raw.(*routeEntry).backend.System()
}
// Route is used to route a given request
@ -166,11 +185,11 @@ func (r *Router) Route(req *logical.Request) (*logical.Response, error) {
}
defer metrics.MeasureSince([]string{"route", string(req.Operation),
strings.Replace(mount, "/", "-", -1)}, time.Now())
me := raw.(*mountEntry)
re := raw.(*routeEntry)
// If the path is tainted, we reject any operation except for
// Rollback and Revoke
if me.tainted {
if re.tainted {
switch req.Operation {
case logical.RevokeOperation, logical.RollbackOperation:
default:
@ -190,12 +209,12 @@ func (r *Router) Route(req *logical.Request) (*logical.Response, error) {
}
// Attach the storage view for the request
req.Storage = me.view
req.Storage = re.view
// Hash the request token unless this is the token backend
clientToken := req.ClientToken
if !strings.HasPrefix(original, "auth/token/") {
req.ClientToken = me.SaltID(req.ClientToken)
req.ClientToken = re.SaltID(req.ClientToken)
}
// If the request is not a login path, then clear the connection
@ -214,7 +233,7 @@ func (r *Router) Route(req *logical.Request) (*logical.Response, error) {
}()
// Invoke the backend
return me.backend.HandleRequest(req)
return re.backend.HandleRequest(req)
}
// RootPath checks if the given path requires root privileges
@ -225,13 +244,13 @@ func (r *Router) RootPath(path string) bool {
if !ok {
return false
}
me := raw.(*mountEntry)
re := raw.(*routeEntry)
// Trim to get remaining path
remain := strings.TrimPrefix(path, mount)
// Check the rootPaths of this backend
match, raw, ok := me.rootPaths.LongestPrefix(remain)
match, raw, ok := re.rootPaths.LongestPrefix(remain)
if !ok {
return false
}
@ -254,13 +273,13 @@ func (r *Router) LoginPath(path string) bool {
if !ok {
return false
}
me := raw.(*mountEntry)
re := raw.(*routeEntry)
// Trim to get remaining path
remain := strings.TrimPrefix(path, mount)
// Check the loginPaths of this backend
match, raw, ok := me.loginPaths.LongestPrefix(remain)
match, raw, ok := re.loginPaths.LongestPrefix(remain)
if !ok {
return false
}

View File

@ -55,12 +55,12 @@ func TestRouter_Mount(t *testing.T) {
view := NewBarrierView(barrier, "logical/")
n := &NoopBackend{}
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view)
err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil {
t.Fatalf("err: %v", err)
}
err = r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view)
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if !strings.Contains(err.Error(), "cannot mount under existing mount") {
t.Fatalf("err: %v", err)
}
@ -104,7 +104,7 @@ func TestRouter_Unmount(t *testing.T) {
view := NewBarrierView(barrier, "logical/")
n := &NoopBackend{}
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view)
err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -129,7 +129,7 @@ func TestRouter_Remount(t *testing.T) {
view := NewBarrierView(barrier, "logical/")
n := &NoopBackend{}
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view)
err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -177,7 +177,7 @@ func TestRouter_RootPath(t *testing.T) {
"policy/*",
},
}
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view)
err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -215,7 +215,7 @@ func TestRouter_LoginPath(t *testing.T) {
"oauth/*",
},
}
err := r.Mount(n, "auth/foo/", uuid.GenerateUUID(), view)
err := r.Mount(n, "auth/foo/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -246,7 +246,7 @@ func TestRouter_Taint(t *testing.T) {
view := NewBarrierView(barrier, "logical/")
n := &NoopBackend{}
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view)
err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -285,7 +285,7 @@ func TestRouter_Untaint(t *testing.T) {
view := NewBarrierView(barrier, "logical/")
n := &NoopBackend{}
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view)
err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -62,10 +62,12 @@ func TestCore(t *testing.T) *Core {
},
}
noopBackends := make(map[string]logical.Factory)
noopBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
return new(framework.Backend), nil
noopBackends["noop"] = func(config *logical.BackendConfig) (logical.Backend, error) {
b := new(framework.Backend)
b.Setup(config)
return b, nil
}
noopBackends["http"] = func(*logical.BackendConfig) (logical.Backend, error) {
noopBackends["http"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return new(rawHTTP), nil
}
logicalBackends := make(map[string]logical.Factory)

View File

@ -48,9 +48,9 @@ type TokenStore struct {
// NewTokenStore is used to construct a token store that is
// backed by the given barrier view.
func NewTokenStore(c *Core) (*TokenStore, error) {
func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error) {
// Create a sub-view
view := c.systemView.SubView(tokenSubPath)
view := c.systemBarrierView.SubView(tokenSubPath)
// Initialize the store
t := &TokenStore{
@ -203,6 +203,8 @@ func NewTokenStore(c *Core) (*TokenStore, error) {
},
}
t.Backend.Setup(config)
return t, nil
}

View File

@ -10,19 +10,30 @@ import (
"github.com/hashicorp/vault/logical"
)
func getBackendConfig(c *Core) *logical.BackendConfig {
return &logical.BackendConfig{
Logger: c.logger,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 30,
},
}
}
func mockTokenStore(t *testing.T) (*Core, *TokenStore, string) {
logger := log.New(os.Stderr, "", log.LstdFlags)
c, _, root := TestCoreUnsealed(t)
ts, err := NewTokenStore(c)
ts, err := NewTokenStore(c, getBackendConfig(c))
if err != nil {
t.Fatalf("err: %v", err)
}
router := NewRouter()
router.Mount(ts, "auth/token/", "", ts.view)
router.Mount(ts, "auth/token/", &MountEntry{UUID: ""}, ts.view)
view := c.systemView.SubView(expirationSubPath)
view := c.systemBarrierView.SubView(expirationSubPath)
exp := NewExpirationManager(router, view, ts, logger)
ts.SetExpirationManager(exp)
return c, ts, root
@ -68,7 +79,7 @@ func TestTokenStore_CreateLookup(t *testing.T) {
}
// New store should share the salt
ts2, err := NewTokenStore(c)
ts2, err := NewTokenStore(c, getBackendConfig(c))
if err != nil {
t.Fatalf("err: %v", err)
}
@ -107,7 +118,7 @@ func TestTokenStore_CreateLookup_ProvidedID(t *testing.T) {
}
// New store should share the salt
ts2, err := NewTokenStore(c)
ts2, err := NewTokenStore(c, getBackendConfig(c))
if err != nil {
t.Fatalf("err: %v", err)
}
@ -219,7 +230,7 @@ func TestTokenStore_Revoke_Leases(t *testing.T) {
// Mount a noop backend
noop := &NoopBackend{}
ts.expiration.router.Mount(noop, "", "", nil)
ts.expiration.router.Mount(noop, "", &MountEntry{UUID: ""}, nil)
ent := &TokenEntry{Path: "test", Policies: []string{"dev", "ops"}}
if err := ts.Create(ent); err != nil {