1
0
vault-redux/vault/audit.go
2024-01-02 10:36:20 -08:00

583 lines
15 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package vault
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"strings"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
)
const (
// coreAuditConfigPath is used to store the audit configuration.
// Audit configuration is protected within the Vault itself, which means it
// can only be viewed or modified after an unseal.
// persistAudit writes the non-local entries of the audit table to this key.
coreAuditConfigPath = "core/audit"
// coreLocalAuditConfigPath is used to store audit information for local
// (non-replicated) mounts. persistAudit writes the entries flagged Local to
// this key.
coreLocalAuditConfigPath = "core/local-audit"
// auditBarrierPrefix is the prefix to the UUID used in the
// barrier view for the audit backends.
auditBarrierPrefix = "audit/"
// auditTableType is the value we expect to find for the audit table and
// corresponding entries. persistAudit refuses to write a table, or an entry,
// whose type differs from this value.
auditTableType = "audit"
)
// errLoadAuditFailed is returned by loadAudits and setupAudits when reading,
// decoding, persisting or setting up the audit table fails. The underlying
// cause is logged rather than wrapped.
var errLoadAuditFailed = errors.New("failed to setup audit table")
// generateAuditTestProbe builds the synthetic log input that enableAudit
// pushes through a newly created audit backend to check that it works. The
// probe is a "request" entry for an "update" on "sys/audit/test" carrying a
// freshly generated request ID.
//
// It returns an error only if the request ID cannot be generated.
func (c *Core) generateAuditTestProbe() (*logical.LogInput, error) {
	probeID, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}

	probeRequest := &logical.Request{
		ID:        probeID,
		Operation: "update",
		Path:      "sys/audit/test",
	}

	// Auth, Response and OuterErr are deliberately left at their zero values.
	probe := &logical.LogInput{
		Type:    "request",
		Request: probeRequest,
	}
	return probe, nil
}
// enableAudit is used to enable a new audit backend.
//
// The entry's path is normalised to end in a slash and must not overlap (as a
// prefix, in either direction) with an already enabled audit path. A UUID and
// accessor are generated when the entry does not carry them yet. The backend
// is constructed against a barrier view that is read-only for the duration of
// the call and, unless the "skip_test" option is "true", must successfully
// log a test message. When updateStorage is true the updated audit table is
// persisted before the in-memory table is swapped and the backend is
// registered with the audit broker.
//
// If any step after the path checker has been added fails, the path checker
// and any reload function installed by newAuditBackend are removed again so
// that a failed attempt leaves nothing registered.
func (c *Core) enableAudit(ctx context.Context, entry *MountEntry, updateStorage bool) error {
	// Ensure we end the path in a slash
	if !strings.HasSuffix(entry.Path, "/") {
		entry.Path += "/"
	}

	// Ensure there is a name
	if entry.Path == "/" {
		return fmt.Errorf("backend path must be specified")
	}

	// Update the audit table
	c.auditLock.Lock()
	defer c.auditLock.Unlock()

	// Look for matching name
	for _, ent := range c.audit.Entries {
		// Existing is sql/mysql/ new is sql/ or
		// existing is sql/ and new is sql/mysql/
		if strings.HasPrefix(ent.Path, entry.Path) || strings.HasPrefix(entry.Path, ent.Path) {
			return fmt.Errorf("path already in use")
		}
	}

	// Generate a new UUID and view
	if entry.UUID == "" {
		entryUUID, err := uuid.GenerateUUID()
		if err != nil {
			return err
		}
		entry.UUID = entryUUID
	}
	if entry.Accessor == "" {
		accessor, err := c.generateMountAccessor("audit_" + entry.Type)
		if err != nil {
			return err
		}
		entry.Accessor = accessor
	}

	viewPath := entry.ViewPath()
	view := NewBarrierView(c.barrier, viewPath)
	addAuditPathChecker(c, entry, view, viewPath)

	// Undo the registrations made on behalf of this entry if we bail out
	// before the backend is enabled. The path-in-use check above guarantees
	// that no enabled device shares this path, so the reload key being
	// cleared can only belong to this attempt. This deferred function runs
	// before the audit lock is released, which removeAuditReloadFunc requires.
	enabled := false
	defer func() {
		if enabled {
			return
		}
		c.removeAuditReloadFunc(entry)
		removeAuditPathChecker(c, entry)
	}()

	origViewReadOnlyErr := view.getReadOnlyErr()

	// Mark the view as read-only until the mounting is complete and
	// ensure that it is reset after. This ensures that there will be no
	// writes during the construction of the backend.
	view.setReadOnlyErr(logical.ErrSetupReadOnly)
	defer view.setReadOnlyErr(origViewReadOnlyErr)

	// Lookup the new backend
	backend, err := c.newAuditBackend(ctx, entry, view, entry.Options)
	if err != nil {
		return err
	}
	if backend == nil {
		return fmt.Errorf("nil audit backend of type %q returned from factory", entry.Type)
	}

	if entry.Options["skip_test"] != "true" {
		// Test the new audit device and report failure if it doesn't work.
		testProbe, err := c.generateAuditTestProbe()
		if err != nil {
			return err
		}
		err = backend.LogTestMessage(ctx, testProbe, entry.Options)
		if err != nil {
			c.logger.Error("new audit backend failed test", "path", entry.Path, "type", entry.Type, "error", err)
			return fmt.Errorf("audit backend failed test message: %w", err)
		}
	}

	// Work on a copy so the live table is only replaced once everything
	// (including persistence) has succeeded.
	newTable := c.audit.shallowClone()
	newTable.Entries = append(newTable.Entries, entry)

	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return err
	}
	entry.NamespaceID = ns.ID
	entry.namespace = ns

	if updateStorage {
		if err := c.persistAudit(ctx, newTable, entry.Local); err != nil {
			return errors.New("failed to update audit table")
		}
	}

	c.audit = newTable

	// Register the backend
	c.auditBroker.Register(entry.Path, backend, view, entry.Local)
	enabled = true

	if c.logger.IsInfo() {
		c.logger.Info("enabled audit backend", "path", entry.Path, "type", entry.Type)
	}

	return nil
}
// disableAudit is used to disable an existing audit backend.
//
// The path is normalised to end in a slash. The first return value reports
// whether a matching audit entry was found: it is false when the path is
// empty, when the lookup fails, or when no entry matches, and true otherwise,
// including when persisting the updated table fails. When updateStorage is
// true the updated table is persisted before any in-memory state is changed.
func (c *Core) disableAudit(ctx context.Context, path string, updateStorage bool) (bool, error) {
	// Ensure we end the path in a slash
	if !strings.HasSuffix(path, "/") {
		path += "/"
	}

	// Ensure there is a name
	if path == "/" {
		return false, fmt.Errorf("backend path must be specified")
	}

	// Remove the entry from the audit table
	c.auditLock.Lock()
	defer c.auditLock.Unlock()

	newTable := c.audit.shallowClone()
	entry, err := newTable.remove(ctx, path)
	if err != nil {
		return false, err
	}

	// Ensure there was a match
	if entry == nil {
		return false, fmt.Errorf("no matching backend")
	}

	// When unmounting all entries the JSON code will load back up from storage
	// as a nil slice, which kills tests...just set it nil explicitly
	if len(newTable.Entries) == 0 {
		newTable.Entries = nil
	}

	if updateStorage {
		// Update the audit table
		if err := c.persistAudit(ctx, newTable, entry.Local); err != nil {
			return true, errors.New("failed to update audit table")
		}
	}

	c.audit = newTable

	// Unmount the backend
	c.auditBroker.Deregister(path)

	// Only drop the reload function once the removal has been committed. If
	// persisting failed above, the device is still enabled and must keep
	// responding to reload requests.
	c.removeAuditReloadFunc(entry)

	if c.logger.IsInfo() {
		c.logger.Info("disabled audit backend", "path", path)
	}

	removeAuditPathChecker(c, entry)

	return true, nil
}
// loadAudits is invoked as part of postUnseal to load the audit table.
//
// It reads the stored audit table and the stored local audit table from the
// barrier, merges the local entries into c.audit, and upgrades legacy data in
// place: a missing table type, and entries lacking a table, accessor or
// namespace ID. Every entry also has its namespace resolved and cached. When
// an upgrade was applied the table is written back, unless c.perfStandby is
// set.
//
// Read, decode and persist failures are logged (by this function or by
// persistAudit) and reported as errLoadAuditFailed. Accessor generation and
// namespace lookup errors are returned unchanged.
func (c *Core) loadAudits(ctx context.Context) error {
	// Fetch both stored tables before taking the audit lock.
	sharedRaw, err := c.barrier.Get(ctx, coreAuditConfigPath)
	if err != nil {
		c.logger.Error("failed to read audit table", "error", err)
		return errLoadAuditFailed
	}
	localRaw, err := c.barrier.Get(ctx, coreLocalAuditConfigPath)
	if err != nil {
		c.logger.Error("failed to read local audit table", "error", err)
		return errLoadAuditFailed
	}

	c.auditLock.Lock()
	defer c.auditLock.Unlock()

	if sharedRaw != nil {
		sharedTable := new(MountTable)
		if err := jsonutil.DecodeJSON(sharedRaw.Value, sharedTable); err != nil {
			c.logger.Error("failed to decode audit table", "error", err)
			return errLoadAuditFailed
		}
		c.audit = sharedTable
	}

	// dirty records whether anything below changed data that must be
	// written back to storage.
	dirty := false

	// Nothing stored yet: start from an empty table.
	if c.audit == nil {
		c.audit = defaultAuditTable()
		dirty = true
	}

	if localRaw != nil {
		localTable := new(MountTable)
		if err := jsonutil.DecodeJSON(localRaw.Value, localTable); err != nil {
			c.logger.Error("failed to decode local audit table", "error", err)
			return errLoadAuditFailed
		}
		if len(localTable.Entries) > 0 {
			c.audit.Entries = append(c.audit.Entries, localTable.Entries...)
		}
	}

	// Upgrade to typed audit table
	if c.audit.Type == "" {
		c.audit.Type = auditTableType
		dirty = true
	}

	// Upgrade to table-scoped entries
	for _, ent := range c.audit.Entries {
		if ent.Table == "" {
			ent.Table = c.audit.Type
			dirty = true
		}

		if ent.Accessor == "" {
			generated, err := c.generateMountAccessor("audit_" + ent.Type)
			if err != nil {
				return err
			}
			ent.Accessor = generated
			dirty = true
		}

		if ent.NamespaceID == "" {
			ent.NamespaceID = namespace.RootNamespaceID
			dirty = true
		}

		// Get the namespace from the namespace ID and load it in memory
		entNS, err := NamespaceByID(ctx, ent.NamespaceID, c)
		if err != nil {
			return err
		}
		if entNS == nil {
			return namespace.ErrNoNamespace
		}
		ent.namespace = entNS
	}

	if dirty && !c.perfStandby {
		if err := c.persistAudit(ctx, c.audit, false); err != nil {
			return errLoadAuditFailed
		}
	}
	return nil
}
// persistAudit is used to persist the audit table after modification.
//
// The given table is split into its local and non-local entries. The local
// entries are always written to coreLocalAuditConfigPath. The non-local
// entries are written to coreAuditConfigPath first, unless localOnly is true.
// Each table is JSON-encoded and compressed before being put into the barrier.
//
// An error is returned, and nothing is written, if the table or any of its
// entries does not carry the audit table type. Encoding and storage errors
// are logged and returned as-is.
func (c *Core) persistAudit(ctx context.Context, table *MountTable, localOnly bool) error {
	if table.Type != auditTableType {
		c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", auditTableType)
		return fmt.Errorf("invalid table type given, not persisting")
	}

	shared := &MountTable{Type: auditTableType}
	local := &MountTable{Type: auditTableType}

	// Validate every entry and sort it into the table it belongs to.
	for _, ent := range table.Entries {
		if ent.Table != table.Type {
			c.logger.Error("given entry to persist in audit table has wrong table value", "path", ent.Path, "entry_table_type", ent.Table, "actual_type", table.Type)
			return fmt.Errorf("invalid audit entry found, not persisting")
		}

		dest := shared
		if ent.Local {
			dest = local
		}
		dest.Entries = append(dest.Entries, ent)
	}

	// store marshals one table and writes it to the barrier under key,
	// logging with the supplied message on either kind of failure.
	store := func(key string, tbl *MountTable, encodeFailMsg, putFailMsg string) error {
		payload, err := jsonutil.EncodeJSONAndCompress(tbl, nil)
		if err != nil {
			c.logger.Error(encodeFailMsg, "error", err)
			return err
		}

		record := &logical.StorageEntry{
			Key:   key,
			Value: payload,
		}
		if err := c.barrier.Put(ctx, record); err != nil {
			c.logger.Error(putFailMsg, "error", err)
			return err
		}
		return nil
	}

	if !localOnly {
		err := store(
			coreAuditConfigPath,
			shared,
			"failed to encode and/or compress audit table",
			"failed to persist audit table",
		)
		if err != nil {
			return err
		}
	}

	// Repeat with local audit
	return store(
		coreLocalAuditConfigPath,
		local,
		"failed to encode and/or compress local audit table",
		"failed to persist local audit table",
	)
}
// setupAudits is invoked after the audit table has been loaded, to initialize
// the audit backends.
//
// A new audit broker is created and every entry of c.audit gets a barrier
// view and a backend registered with that broker. Each view is switched to
// read-only here and a function restoring its previous read-only error is
// appended to c.postUnsealFuncs. Entries whose backend cannot be created are
// logged and skipped. errLoadAuditFailed is returned, and c.auditBroker is
// left untouched, only when the table has entries and none of them could be
// set up.
func (c *Core) setupAudits(ctx context.Context) error {
	auditLog := c.baseLogger.Named("audit")
	c.AddLogger(auditLog)
	newBroker := NewAuditBroker(auditLog)

	c.auditLock.Lock()
	defer c.auditLock.Unlock()

	registered := 0
	for _, ent := range c.audit.Entries {
		// Create a barrier view using the UUID
		viewPath := ent.ViewPath()
		view := NewBarrierView(c.barrier, viewPath)
		addAuditPathChecker(c, ent, view, viewPath)

		// Mark the view as read-only until the mounting is complete and
		// ensure that it is reset after. This ensures that there will be no
		// writes during the construction of the backend.
		previousReadOnlyErr := view.getReadOnlyErr()
		view.setReadOnlyErr(logical.ErrSetupReadOnly)
		c.postUnsealFuncs = append(c.postUnsealFuncs, func() {
			view.setReadOnlyErr(previousReadOnlyErr)
		})

		// Initialize the backend
		backend, err := c.newAuditBackend(ctx, ent, view, ent.Options)
		switch {
		case err != nil:
			c.logger.Error("failed to create audit entry", "path", ent.Path, "error", err)
			continue
		case backend == nil:
			c.logger.Error("created audit entry was nil", "path", ent.Path, "type", ent.Type)
			continue
		}

		// Mount the backend
		newBroker.Register(ent.Path, backend, view, ent.Local)
		registered++
	}

	if registered == 0 && len(c.audit.Entries) > 0 {
		return errLoadAuditFailed
	}

	c.auditBroker = newBroker
	return nil
}
// teardownAudits is used before we seal the vault to reset the audit backends
// to their unloaded state: the reload function and path checker of every
// entry are removed, then the audit table and the audit broker are cleared.
// loadAudits and setupAudits populate them again. It always returns nil.
func (c *Core) teardownAudits() error {
	c.auditLock.Lock()
	defer c.auditLock.Unlock()

	if table := c.audit; table != nil {
		for _, ent := range table.Entries {
			c.removeAuditReloadFunc(ent)
			removeAuditPathChecker(c, ent)
		}
	}

	c.audit, c.auditBroker = nil, nil
	return nil
}
// removeAuditReloadFunc removes the reload func from the working set. The
// audit lock needs to be held before calling this.
//
// Only entries of type "file" are affected; for them every reload function
// stored under the entry's "audit_file|<path>" key is dropped.
func (c *Core) removeAuditReloadFunc(entry *MountEntry) {
	if entry.Type != "file" {
		return
	}

	reloadKey := "audit_file|" + entry.Path

	c.reloadFuncsLock.Lock()
	defer c.reloadFuncsLock.Unlock()

	if c.logger.IsDebug() {
		c.baseLogger.Named("audit").Debug("removing reload function", "path", entry.Path)
	}
	delete(c.reloadFuncs, reloadKey)
}
// newAuditBackend is used to create and configure a new audit backend by name.
//
// The factory registered for entry.Type in c.auditBackends is invoked with
// the given view as salt storage, an HMAC-SHA256 salt configuration and conf
// as the backend configuration. An unknown type, a factory error, or a nil
// backend all result in an error.
//
// For "file" backends a reload function is appended under the
// "audit_file|<path>" key of c.reloadFuncs. Note that this closure retains
// the ctx passed to this call. For "file", "socket" and "syslog" backends a
// selection of the entry's options is logged at debug level.
func (c *Core) newAuditBackend(ctx context.Context, entry *MountEntry, view logical.Storage, conf map[string]string) (audit.Backend, error) {
	factory, known := c.auditBackends[entry.Type]
	if !known {
		return nil, fmt.Errorf("unknown backend type: %q", entry.Type)
	}

	backendConfig := &audit.BackendConfig{
		SaltView: view,
		SaltConfig: &salt.Config{
			HMAC:     sha256.New,
			HMACType: "hmac-sha256",
			Location: salt.DefaultLocation,
		},
		Config: conf,
	}

	be, err := factory(ctx, backendConfig)
	if err != nil {
		return nil, err
	}
	if be == nil {
		return nil, fmt.Errorf("nil backend returned from %q factory function", entry.Type)
	}

	auditLogger := c.baseLogger.Named("audit")
	c.AddLogger(auditLogger)

	switch entry.Type {
	case "file":
		reloadKey := "audit_file|" + entry.Path
		reload := func() error {
			if auditLogger.IsInfo() {
				auditLogger.Info("reloading file audit backend", "path", entry.Path)
			}
			return be.Reload(ctx)
		}

		c.reloadFuncsLock.Lock()
		if auditLogger.IsDebug() {
			auditLogger.Debug("adding reload function", "path", entry.Path)
			if entry.Options != nil {
				auditLogger.Debug("file backend options", "path", entry.Path, "file_path", entry.Options["file_path"])
			}
		}
		c.reloadFuncs[reloadKey] = append(c.reloadFuncs[reloadKey], reload)
		c.reloadFuncsLock.Unlock()

	case "socket":
		if auditLogger.IsDebug() && entry.Options != nil {
			auditLogger.Debug("socket backend options", "path", entry.Path, "address", entry.Options["address"], "socket type", entry.Options["socket_type"])
		}

	case "syslog":
		if auditLogger.IsDebug() && entry.Options != nil {
			auditLogger.Debug("syslog backend options", "path", entry.Path, "facility", entry.Options["facility"], "tag", entry.Options["tag"])
		}
	}

	return be, nil
}
// defaultAuditTable creates a default audit table: a table of the audit type
// with no entries.
func defaultAuditTable() *MountTable {
	return &MountTable{Type: auditTableType}
}
// AuditLogger is implemented by types that can submit request and response
// log inputs for auditing. basicAuditor and genericAuditor implement it by
// forwarding to the Core's audit broker.
type AuditLogger interface {
// AuditRequest submits a request log input for auditing.
AuditRequest(ctx context.Context, input *logical.LogInput) error
// AuditResponse submits a response log input for auditing.
AuditResponse(ctx context.Context, input *logical.LogInput) error
}
// basicAuditor is an AuditLogger that forwards log inputs unchanged to the
// audit broker of the Core it wraps.
type basicAuditor struct {
// c is the Core whose audit broker and audited headers are used.
c *Core
}
// AuditRequest forwards input to the Core's audit broker as a request log,
// together with the Core's audited headers. It returns consts.ErrSealed when
// the Core has no audit broker.
func (b *basicAuditor) AuditRequest(ctx context.Context, input *logical.LogInput) error {
	core := b.c
	if core.auditBroker != nil {
		return core.auditBroker.LogRequest(ctx, input, core.auditedHeaders)
	}
	return consts.ErrSealed
}
// AuditResponse forwards input to the Core's audit broker as a response log,
// together with the Core's audited headers. It returns consts.ErrSealed when
// the Core has no audit broker.
func (b *basicAuditor) AuditResponse(ctx context.Context, input *logical.LogInput) error {
	core := b.c
	if core.auditBroker != nil {
		return core.auditBroker.LogResponse(ctx, input, core.auditedHeaders)
	}
	return consts.ErrSealed
}
// genericAuditor is an AuditLogger that tags every log input with a mount
// type and logs it under a fixed namespace before handing it to the audit
// broker of the Core it wraps.
type genericAuditor struct {
// c is the Core whose audit broker and audited headers are used.
c *Core
// mountType is prepended to "-request" / "-response" to form the log
// input's Type.
mountType string
// namespace is placed on the context passed to the audit broker.
namespace *namespace.Namespace
}
// AuditRequest logs input through the Core's audit broker as a request of
// type "<mountType>-request", using a context that carries the auditor's
// namespace. The caller's input is not modified: the type is set on a shallow
// copy.
//
// It returns consts.ErrSealed when the Core has no audit broker, matching
// basicAuditor, instead of calling a method on a nil broker.
func (g genericAuditor) AuditRequest(ctx context.Context, input *logical.LogInput) error {
	// teardownAudits sets the broker to nil, so it must be checked before use.
	if g.c.auditBroker == nil {
		return consts.ErrSealed
	}

	ctx = namespace.ContextWithNamespace(ctx, g.namespace)

	// Copy so the caller's LogInput keeps its original Type.
	logInput := *input
	logInput.Type = g.mountType + "-request"
	return g.c.auditBroker.LogRequest(ctx, &logInput, g.c.auditedHeaders)
}
// AuditResponse logs input through the Core's audit broker as a response of
// type "<mountType>-response", using a context that carries the auditor's
// namespace. The caller's input is not modified: the type is set on a shallow
// copy.
//
// It returns consts.ErrSealed when the Core has no audit broker, matching
// basicAuditor, instead of calling a method on a nil broker.
func (g genericAuditor) AuditResponse(ctx context.Context, input *logical.LogInput) error {
	// teardownAudits sets the broker to nil, so it must be checked before use.
	if g.c.auditBroker == nil {
		return consts.ErrSealed
	}

	ctx = namespace.ContextWithNamespace(ctx, g.namespace)

	// Copy so the caller's LogInput keeps its original Type.
	logInput := *input
	logInput.Type = g.mountType + "-response"
	return g.c.auditBroker.LogResponse(ctx, &logInput, g.c.auditedHeaders)
}