1
0
vault-redux/vault/rollback.go
hc-github-team-secure-vault-core d3028e6aa7
backport of commit 5f0638aa8bcb3e6188d4b92fabca29f15460203a (#25448) (#25552)
* prevent deadlock

* rollbacks not done for sync invalidate

* add check for the path before deleting

* revert sync invalidation doesn't do rollbacks

* add known issue

* changelog

* fix formatting issue

Co-authored-by: miagilepner <mia.epner@hashicorp.com>
2024-02-21 17:22:47 +01:00

443 lines
13 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package vault
import (
"context"
"errors"
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
metrics "github.com/armon/go-metrics"
"github.com/gammazero/workerpool"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
)
const (
	// RollbackDefaultNumWorkers is the default size of the rollback worker
	// pool.
	// NOTE(review): not referenced in this part of the file; presumably used
	// where core.numRollbackWorkers is populated — confirm.
	RollbackDefaultNumWorkers = 256
	// RollbackWorkersEnvVar names the environment variable that, when set to
	// a positive integer, overrides the configured number of rollback workers.
	RollbackWorkersEnvVar = "VAULT_ROLLBACK_WORKERS"
)
// rollbackCanceled is returned by a rollback attempt that finds its state's
// isCanceled channel closed, i.e. the attempt was superseded before (or while)
// it tried to run.
var rollbackCanceled = errors.New("rollback attempt canceled")
// RollbackManager is responsible for performing rollbacks of partial
// secrets within logical backends.
//
// During normal operations, it is possible for logical backends to
// error partially through an operation. These are called "partial secrets":
// they are never sent back to a user, but they do need to be cleaned up.
// This manager handles that by periodically (on a timer) requesting that the
// backends clean up.
//
// The RollbackManager periodically initiates a logical.RollbackOperation
// on every mounted logical backend. It ensures that only one rollback operation
// is in-flight at any given time within a single seal/unseal phase.
type RollbackManager struct {
	// logger receives the manager's lifecycle and error messages.
	logger log.Logger
	// backends returns the current mount entries of both logical and
	// credential backends. The function is responsible for its own locking;
	// callers receive a plain slice and have nothing to unlock.
	backends func() []*MountEntry
	// router is used to look up the backend for a mount and to route the
	// rollback request to it.
	router *Router
	// period is the interval of the ticker that triggers scheduled rollbacks.
	period time.Duration
	// inflightAll counts every rollback state that has been created and not
	// yet finished.
	inflightAll sync.WaitGroup
	// inflight maps a namespace-qualified mount path to the rollback state
	// currently registered for it. Access is guarded by inflightLock.
	inflight     map[string]*rollbackState
	inflightLock sync.RWMutex
	// doneCh is closed when the run loop exits.
	doneCh chan struct{}
	// shutdown records that Stop has already closed shutdownCh; both are
	// guarded by shutdownLock.
	shutdown     bool
	shutdownCh   chan struct{}
	shutdownLock sync.Mutex
	// stopTicker is closed by StopTicker to halt scheduled rollbacks;
	// tickerIsStopped prevents closing it twice.
	stopTicker      chan struct{}
	tickerIsStopped bool
	// quitContext is the parent context for scheduled rollback requests.
	quitContext context.Context
	// runner is the worker pool that executes scheduled rollback attempts.
	runner *workerpool.WorkerPool
	// core supplies the state lock, the rollback period and the configured
	// worker count.
	core *Core
	// This channel is used for testing
	rollbacksDoneCh chan struct{}
}
// rollbackState is used to track the state of a single rollback attempt
type rollbackState struct {
	// lastError is the result recorded for the attempt by finishRollback.
	lastError error
	// The embedded WaitGroup is incremented once when the state is created
	// and released when the attempt finishes, so callers can Wait for it.
	sync.WaitGroup
	// cancelLockGrabCtx is canceled (via cancelLockGrabCtxCancel) to tell an
	// attempt to stop waiting for the core state lock.
	cancelLockGrabCtx       context.Context
	cancelLockGrabCtxCancel context.CancelFunc
	// scheduled is the time that this job was created and submitted to the
	// rollbackRunner
	scheduled time.Time
	// isRunning is closed as soon as the attempt starts executing.
	isRunning chan struct{}
	// isCanceled is closed when the attempt has been superseded and should
	// give up without performing the rollback.
	isCanceled chan struct{}
}
// NewRollbackManager builds a RollbackManager that is ready to be started.
//
// ctx becomes the parent context of scheduled rollback requests, backendsFunc
// supplies the mount entries to roll back, and core provides the rollback
// period, the worker count and the state lock. The worker pool is created
// here; the periodic loop only begins once Start is called.
func NewRollbackManager(ctx context.Context, logger log.Logger, backendsFunc func() []*MountEntry, router *Router, core *Core) *RollbackManager {
	manager := &RollbackManager{
		core:        core,
		router:      router,
		logger:      logger,
		backends:    backendsFunc,
		quitContext: ctx,
		period:      core.rollbackPeriod,
		inflight:    make(map[string]*rollbackState),
		doneCh:      make(chan struct{}),
		shutdownCh:  make(chan struct{}),
		stopTicker:  make(chan struct{}),
	}

	workers := manager.numRollbackWorkers()
	startMsg := fmt.Sprintf("Starting the rollback manager with %d workers", workers)
	manager.logger.Info(startMsg)
	manager.runner = workerpool.New(workers)

	return manager
}
// numRollbackWorkers returns the size to use for the rollback worker pool.
//
// The value configured on core is used unless the environment variable named
// by RollbackWorkersEnvVar is set to a positive integer, in which case that
// value wins. An unparsable or non-positive override is logged as a warning
// and ignored.
func (m *RollbackManager) numRollbackWorkers() int {
	configured := m.core.numRollbackWorkers

	raw := os.Getenv(RollbackWorkersEnvVar)
	if raw == "" {
		return configured
	}

	override, parseErr := strconv.Atoi(raw)
	if parseErr == nil && override >= 1 {
		return override
	}

	m.logger.Warn(fmt.Sprintf("%s must be a positive integer, but was %s", RollbackWorkersEnvVar, raw))
	return configured
}
// Start starts the rollback manager by launching its periodic run loop in a
// new goroutine. It returns immediately; use Stop to terminate the loop.
func (m *RollbackManager) Start() {
	go m.run()
}
// Stop shuts the manager down and waits for in-flight rollbacks to complete.
//
// The first call signals the run loop through shutdownCh and blocks until the
// loop has exited. Every call, including repeated ones, then stops the worker
// pool and waits for it. Calls are serialized by shutdownLock.
func (m *RollbackManager) Stop() {
	m.shutdownLock.Lock()
	defer m.shutdownLock.Unlock()

	if firstCall := !m.shutdown; firstCall {
		m.shutdown = true
		close(m.shutdownCh)
		// Wait for the run loop to acknowledge the shutdown.
		<-m.doneCh
	}

	m.runner.StopWait()
}
// StopTicker halts the manager's automatic ticker so that no further
// scheduled rollbacks are triggered. This lets tests exercise a plugin's
// periodic function without racing against the rollback manager. Calling it
// more than once is a no-op.
//
// THIS SHOULD ONLY BE CALLED FROM TEST HELPERS.
func (m *RollbackManager) StopTicker() {
	if m.tickerIsStopped {
		return
	}
	close(m.stopTicker)
	m.tickerIsStopped = true
}
// run is the long-running loop that periodically invokes rollbacks.
//
// Every tick of a ticker with interval m.period triggers a rollback on all
// backends. The loop exits when shutdownCh is closed, and doneCh is closed on
// the way out so that Stop can observe the exit. When stopTicker is closed
// (tests only) the ticker is stopped and the loop keeps waiting for shutdown.
func (m *RollbackManager) run() {
	m.logger.Info("starting rollback manager")
	tick := time.NewTicker(m.period)
	defer tick.Stop()
	defer close(m.doneCh)

	// Work on a local copy of the channel so it can be disabled after it
	// fires. A closed channel is always ready to receive, so selecting on it
	// repeatedly would turn this loop into a busy spin until shutdown;
	// a nil channel, by contrast, is never ready.
	stopTicker := m.stopTicker

	for {
		select {
		case <-tick.C:
			m.triggerRollbacks()
		case <-m.shutdownCh:
			m.logger.Info("stopping rollback manager")
			return
		case <-stopTicker:
			m.logger.Info("stopping rollback manager ticker for tests")
			tick.Stop()
			stopTicker = nil
		}
	}
}
// triggerRollbacks schedules a rollback for every mount returned by
// m.backends. Credential mounts are addressed through the credential route
// prefix, and each rollback is keyed by the namespace path plus the mount's
// route path.
func (m *RollbackManager) triggerRollbacks() {
	for _, entry := range m.backends() {
		routePath := entry.Path
		if entry.Table == credentialTableType {
			routePath = credentialRoutePrefix + routePath
		}

		nsCtx := namespace.ContextWithNamespace(m.quitContext, entry.namespace)

		// When the mount is filtered, the backend will be nil; there is
		// nothing to roll back in that case.
		if m.router.MatchingBackend(nsCtx, routePath) == nil {
			continue
		}

		// Start a rollback if one is not already in flight for this path.
		m.startOrLookupRollback(nsCtx, entry.namespace.Path+routePath, true)
	}
}
// lookupRollbackLocked returns the inflight rollback state registered for
// fullPath, or nil when there is none. Callers must hold inflightLock.
//
// Because it runs as part of every scheduled rollback, it also publishes the
// "rollback inflight" and "rollback queued" gauges.
func (m *RollbackManager) lookupRollbackLocked(fullPath string) *rollbackState {
	// Sample both gauge values before the lookup, then publish them after it.
	queued := float32(m.runner.WaitingQueueSize())
	inflightCount := float32(len(m.inflight))

	state := m.inflight[fullPath]

	metrics.SetGauge([]string{"rollback", "inflight"}, inflightCount)
	metrics.SetGauge([]string{"rollback", "queued"}, queued)
	return state
}
// newRollbackLocked allocates a fresh rollback state for fullPath, registers
// it in the inflight map (replacing any existing entry for that path) and
// accounts for it in both the state's own WaitGroup and inflightAll.
// Callers must hold inflightLock.
func (m *RollbackManager) newRollbackLocked(fullPath string) *rollbackState {
	lockGrabCtx, cancelLockGrab := context.WithCancel(context.Background())

	state := &rollbackState{
		cancelLockGrabCtx:       lockGrabCtx,
		cancelLockGrabCtxCancel: cancelLockGrab,
		isRunning:               make(chan struct{}),
		isCanceled:              make(chan struct{}),
		scheduled:               time.Now(),
	}
	m.inflight[fullPath] = state

	// Each state is released exactly once, by finishRollback.
	state.Add(1)
	m.inflightAll.Add(1)
	return state
}
// startOrLookupRollback returns the inflight rollback for fullPath, creating
// and asynchronously scheduling a new one when none exists.
//
// grabStatelock is forwarded to attemptRollback and controls whether the
// attempt acquires the core state lock itself. The returned state can be
// waited on for completion.
func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath string, grabStatelock bool) *rollbackState {
	m.inflightLock.Lock()
	defer m.inflightLock.Unlock()

	if existing := m.lookupRollbackLocked(fullPath); existing != nil {
		return existing
	}

	// Nothing is in flight for this path, so kick off a new attempt.
	state := m.newRollbackLocked(fullPath)

	select {
	case <-m.doneCh:
		// The manager has already shut down; submitting the task would panic.
		// The state still has to be finished so that it is removed from the
		// map and the wait groups are released. inflightLock is already held
		// here, so finishRollback must not try to take it again.
		m.finishRollback(state, errors.New("rollback manager is stopped"), fullPath, false)
		return state
	default:
	}

	m.runner.Submit(func() {
		m.attemptRollback(ctx, fullPath, state, grabStatelock)
		// Best-effort notification for tests; never block the worker.
		select {
		case m.rollbacksDoneCh <- struct{}{}:
		default:
		}
	})
	return state
}
// finishRollback records the outcome of a rollback attempt and releases
// everything waiting on it.
//
// err is stored as rs.lastError before the state's WaitGroup and inflightAll
// are released, so waiters observe the final result. The state is then
// removed from the inflight map. grabInflightLock must be false when the
// caller already holds inflightLock and true otherwise.
func (m *RollbackManager) finishRollback(rs *rollbackState, err error, fullPath string, grabInflightLock bool) {
	rs.lastError = err
	rs.Done()
	m.inflightAll.Done()

	if grabInflightLock {
		m.inflightLock.Lock()
		defer m.inflightLock.Unlock()
	}

	// Only remove the entry if it still belongs to this attempt. A canceled
	// attempt may finish after a newer attempt has been registered under the
	// same path; deleting unconditionally would drop the newer attempt from
	// the map while it is still running.
	if m.inflight[fullPath] == rs {
		delete(m.inflight, fullPath)
	}
}
// attemptRollback invokes a RollbackOperation for the given path and reports
// the result through finishRollback, which stores it in rs.lastError.
//
// ctx must carry the namespace of the mount. fullPath is the
// namespace-qualified mount path. When grabStatelock is true the attempt
// acquires core's state read lock itself, giving up if the manager shuts
// down, the attempt is canceled, or another caller signals (via
// rs.cancelLockGrabCtx) that it already holds the lock on our behalf.
//
// Unsupported-operation and read-only-storage failures are not treated as
// errors. The returned error is the same value recorded in rs.lastError.
func (m *RollbackManager) attemptRollback(ctx context.Context, fullPath string, rs *rollbackState, grabStatelock bool) (err error) {
	close(rs.isRunning)
	// Defer a closure rather than the call itself: arguments of a deferred
	// call are evaluated at the defer statement, where err is still nil, so a
	// direct defer would always record a nil error. The closure reads the
	// named result when the function actually returns.
	defer func() {
		m.finishRollback(rs, err, fullPath, true)
	}()

	// Bail out early if this attempt was superseded while it was queued.
	select {
	case <-rs.isCanceled:
		return rollbackCanceled
	default:
	}

	metrics.MeasureSince([]string{"rollback", "waiting"}, rs.scheduled)
	// time.Now() is evaluated here, so the metric covers the whole attempt.
	defer metrics.MeasureSince([]string{"rollback", "attempt", strings.ReplaceAll(fullPath, "/", "-")}, time.Now())

	ns, err := namespace.FromContext(ctx)
	if err != nil {
		m.logger.Error("rollback failed to derive namespace from context", "path", fullPath)
		return err
	}
	if ns == nil {
		m.logger.Error("rollback found no namespace", "path", fullPath)
		return namespace.ErrNoNamespace
	}

	// Invoke a RollbackOperation
	req := &logical.Request{
		Operation: logical.RollbackOperation,
		Path:      ns.TrimmedPath(fullPath),
	}

	releaseLock := true
	if grabStatelock {
		// doneCh lets the watcher goroutine below exit when this function
		// returns, so it never outlives the attempt.
		doneCh := make(chan struct{})
		defer close(doneCh)

		// stopCh is closed as soon as any reason to stop waiting for the
		// state lock occurs.
		stopCh := make(chan struct{})
		go func() {
			defer close(stopCh)
			select {
			case <-m.shutdownCh:
			case <-rs.cancelLockGrabCtx.Done():
			case <-doneCh:
			case <-rs.isCanceled:
			}
		}()

		// Grab the statelock or stop
		l := newLockGrabber(m.core.stateLock.RLock, m.core.stateLock.RUnlock, stopCh)
		go l.grab()
		if stopped := l.lockOrStop(); stopped {
			// If we stopped due to shutdown, return. Otherwise another thread
			// is holding the lock for us, continue on.
			select {
			case <-m.shutdownCh:
				return errors.New("rollback shutting down")
			case <-rs.isCanceled:
				return rollbackCanceled
			default:
				releaseLock = false
			}
		}
	}

	var cancelFunc context.CancelFunc
	ctx, cancelFunc = context.WithTimeout(ctx, DefaultMaxRequestDuration)
	resp, err := m.router.Route(ctx, req)
	if grabStatelock && releaseLock {
		m.core.stateLock.RUnlock()
	}
	cancelFunc()

	// If the error is an unsupported operation, then it doesn't
	// matter, the backend doesn't support it.
	if err == logical.ErrUnsupportedOperation {
		err = nil
	}
	// If we failed due to read-only storage, we can't do anything; ignore
	if (err != nil && strings.Contains(err.Error(), logical.ErrReadOnly.Error())) ||
		(resp.IsError() && strings.Contains(resp.Error().Error(), logical.ErrReadOnly.Error())) {
		err = nil
	}
	if err != nil {
		m.logger.Error("error rolling back", "path", fullPath, "error", err)
	}
	return err
}
// Rollback is used to trigger an immediate rollback of the path,
// or to join an existing rollback operation if in flight. Caller should have
// core's statelock held (write OR read). If an already inflight rollback is
// happening this function will simply wait for it to complete
//
// ctx must carry a namespace; path is relative to that namespace. The
// returned error is the lastError recorded on the rollback state that was
// waited on, or the error from deriving the namespace.
func (m *RollbackManager) Rollback(ctx context.Context, path string) error {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return err
	}
	// Inflight rollbacks are keyed by the namespace-qualified path.
	fullPath := ns.Path + path
	m.inflightLock.Lock()
	rs := m.lookupRollbackLocked(fullPath)
	if rs != nil {
		// Since we have the statelock held, tell any inflight rollback to give up
		// trying to acquire it. This will prevent deadlocks in the case where we
		// have the write lock. In the case where it was waiting to grab
		// a read lock it will then simply continue with the rollback
		// operation under the protection of our write lock.
		rs.cancelLockGrabCtxCancel()
		select {
		case <-rs.isRunning:
			// if the rollback has started then we should wait for it to complete
			// The inflight lock is released first so the running attempt can
			// take it when it finishes.
			m.inflightLock.Unlock()
			rs.Wait()
			return rs.lastError
		default:
		}
		// if the rollback hasn't started and there's no capacity, we could
		// end up deadlocking. Cancel the existing rollback and start a new
		// one.
		close(rs.isCanceled)
	}
	// Register a fresh state (replacing any canceled one in the map) and run
	// the attempt synchronously on the caller's goroutine.
	rs = m.newRollbackLocked(fullPath)
	m.inflightLock.Unlock()
	// The return value is intentionally discarded; the outcome is read back
	// from rs.lastError, which finishRollback is responsible for setting.
	// grabStatelock is false because the caller already holds the statelock.
	m.attemptRollback(ctx, fullPath, rs, false)
	rs.Wait()
	return rs.lastError
}
// The methods below are the hooks from core that are called pre/post seal.
// startRollback creates the rollback manager and starts it; it is the hook
// invoked after unsealing. The manager is handed a function that snapshots
// the entries of the mount and auth tables under their respective read locks.
// It always returns nil.
func (c *Core) startRollback() error {
	collectEntries := func() []*MountEntry {
		entries := []*MountEntry{}

		c.mountsLock.RLock()
		defer c.mountsLock.RUnlock()
		// During teardown/setup after a leader change or unseal the table
		// may briefly be nil, so guard against that.
		if c.mounts != nil {
			entries = append(entries, c.mounts.Entries...)
		}

		c.authLock.RLock()
		defer c.authLock.RUnlock()
		// The same nil-table window applies to the auth table.
		if c.auth != nil {
			entries = append(entries, c.auth.Entries...)
		}

		return entries
	}

	logger := c.baseLogger.Named("rollback")
	c.AddLogger(logger)

	c.rollback = NewRollbackManager(c.activeContext, logger, collectEntries, c.router, c)
	c.rollback.Start()
	return nil
}
// stopRollback stops the rollback manager, if one exists, and clears the
// reference to it; it is the hook invoked before sealing. It always returns
// nil.
func (c *Core) stopRollback() error {
	if c.rollback == nil {
		return nil
	}
	c.rollback.Stop()
	c.rollback = nil
	return nil
}