// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package vault

import (
	"bytes"
	"context"
	"encoding/base64"
	"fmt"
	"image/png"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-memdb"
	"github.com/hashicorp/go-multierror"
	"github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/helper/identity"
	"github.com/hashicorp/vault/helper/identity/mfa"
	"github.com/hashicorp/vault/helper/namespace"
	"github.com/hashicorp/vault/sdk/framework"
	"github.com/hashicorp/vault/sdk/helper/jsonutil"
	"github.com/hashicorp/vault/sdk/helper/strutil"
	"github.com/hashicorp/vault/sdk/logical"
	"github.com/hashicorp/vault/vault/quotas"
	"github.com/mitchellh/mapstructure"
	"github.com/patrickmn/go-cache"
	otplib "github.com/pquerna/otp"
	totplib "github.com/pquerna/otp/totp"
)

const (
	// mfaMethodTypeTOTP is the method type string identifying TOTP MFA methods.
	mfaMethodTypeTOTP = "totp"

	// memDBLoginMFAConfigsTable is the MemDB table name used for login MFA
	// method configs.
	memDBLoginMFAConfigsTable = "login_mfa_configs"
	// memDBMFALoginEnforcementsTable is the MemDB table name used for login
	// MFA enforcement configs.
	memDBMFALoginEnforcementsTable = "login_enforcements"
	// mfaTOTPKeysPrefix is the barrier key prefix under which TOTP keys are
	// stored, keyed by "<methodID>/<entityID>".
	mfaTOTPKeysPrefix = systemBarrierPrefix + "mfa/totpkeys/"

	// loginMFAConfigPrefix is the storage prefix for persisting login MFA method
	// configs
	loginMFAConfigPrefix = "login-mfa/method/"
	// mfaLoginEnforcementPrefix is the storage prefix for persisting login MFA
	// enforcement configs.
	mfaLoginEnforcementPrefix = "login-mfa/enforcement/"
)

// totpKey is the JSON envelope used when a TOTP key is written to, or read
// from, the barrier.
type totpKey struct {
	Key string `json:"key"`
}

// loginMFAPaths returns the API endpoints to configure the new style
// login MFA.
The following paths are supported: // mfa/method/:mfa_method - management of MFA method IDs, which can be used for configuration // mfa/login_enforcement/:config_name - configures single or two phase MFA auth func (b *SystemBackend) loginMFAPaths() []*framework.Path { return []*framework.Path{ { Pattern: "mfa/validate", DisplayAttrs: &framework.DisplayAttributes{ OperationPrefix: "mfa", OperationVerb: "validate", }, Fields: map[string]*framework.FieldSchema{ "mfa_request_id": { Type: framework.TypeString, Description: "ID for this MFA request", Required: true, }, "mfa_payload": { Type: framework.TypeMap, Description: "A map from MFA method ID to a slice of passcodes or an empty slice if the method does not use passcodes", Required: true, }, }, Operations: map[logical.Operation]framework.OperationHandler{ logical.UpdateOperation: &framework.PathOperation{ Callback: b.Core.loginMFABackend.handleMFALoginValidate, Responses: map[int][]framework.Response{ http.StatusOK: {{ Description: "OK", }}, }, Summary: "Validates the login for the given MFA methods. 
Upon successful validation, it returns an auth response containing the client token", ForwardPerformanceStandby: true, }, }, }, } } func genericOptionalUUIDRegex(name string) string { return fmt.Sprintf("(/(?P<%s>[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}))?", name) } type MFABackend struct { Core *Core mfaLock *sync.RWMutex db *memdb.MemDB mfaLogger hclog.Logger namespacer Namespacer methodTable string usedCodes *cache.Cache } type LoginMFABackend struct { *MFABackend } func loginMFASchemaFuncs() []func() *memdb.TableSchema { return []func() *memdb.TableSchema{ loginMFAConfigTableSchema, loginEnforcementTableSchema, } } func NewLoginMFABackend(core *Core, logger hclog.Logger) *LoginMFABackend { b := NewMFABackend(core, logger, memDBLoginMFAConfigsTable, loginMFASchemaFuncs()) return &LoginMFABackend{b} } func NewMFABackend(core *Core, logger hclog.Logger, prefix string, schemaFuncs []func() *memdb.TableSchema) *MFABackend { db, _ := SetupMFAMemDB(schemaFuncs) return &MFABackend{ Core: core, mfaLock: &sync.RWMutex{}, db: db, mfaLogger: logger.Named("mfa"), namespacer: core, methodTable: prefix, } } func SetupMFAMemDB(schemaFuncs []func() *memdb.TableSchema) (*memdb.MemDB, error) { mfaSchemas := &memdb.DBSchema{ Tables: make(map[string]*memdb.TableSchema), } for _, schemaFunc := range schemaFuncs { schema := schemaFunc() if _, ok := mfaSchemas.Tables[schema.Name]; ok { panic(fmt.Sprintf("duplicate table name: %s", schema.Name)) } mfaSchemas.Tables[schema.Name] = schema } db, err := memdb.NewMemDB(mfaSchemas) if err != nil { return nil, err } return db, nil } func (b *LoginMFABackend) ResetLoginMFAMemDB() error { var err error db, err := SetupMFAMemDB(loginMFASchemaFuncs()) if err != nil { return err } b.db = db return nil } func (i *IdentityStore) handleMFAMethodListTOTP(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { return i.handleMFAMethodList(ctx, req, d, mfaMethodTypeTOTP) } func (i 
*IdentityStore) handleMFAMethodListGlobal(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { keys, configInfo, err := i.mfaBackend.mfaMethodList(ctx, "") if err != nil { return nil, err } return logical.ListResponseWithInfo(keys, configInfo), nil } func (i *IdentityStore) handleMFAMethodList(ctx context.Context, req *logical.Request, d *framework.FieldData, methodType string) (*logical.Response, error) { keys, configInfo, err := i.mfaBackend.mfaMethodList(ctx, methodType) if err != nil { return nil, err } return logical.ListResponseWithInfo(keys, configInfo), nil } func (i *IdentityStore) handleMFAMethodTOTPRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { return i.handleMFAMethodReadCommon(ctx, req, d, mfaMethodTypeTOTP) } func (i *IdentityStore) handleMFAMethodReadGlobal(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { return i.handleMFAMethodReadCommon(ctx, req, d, "") } func (i *IdentityStore) handleMFAMethodReadCommon(ctx context.Context, req *logical.Request, d *framework.FieldData, methodType string) (*logical.Response, error) { methodID := d.Get("method_id").(string) if methodID == "" { return logical.ErrorResponse("missing method ID"), nil } ns, err := namespace.FromContext(ctx) if err != nil { return nil, err } respData, err := i.mfaBackend.mfaConfigReadByMethodID(methodID) if err != nil { return nil, err } if respData == nil { return nil, nil } mfaNs, err := i.namespacer.NamespaceByID(ctx, respData["namespace_id"].(string)) if err != nil { return nil, err } // reading the method config either from the same namespace or from the parent or from the child should all work if !(ns.ID == mfaNs.ID || mfaNs.HasParent(ns) || ns.HasParent(mfaNs)) { return logical.ErrorResponse("request namespace does not match method namespace"), logical.ErrPermissionDenied } if methodType != "" && respData["type"] != methodType { return 
logical.ErrorResponse("failed to find the method ID under MFA type %s.", methodType), nil } return &logical.Response{ Data: respData, }, nil } func (i *IdentityStore) handleMFAMethodUpdateCommon(ctx context.Context, req *logical.Request, d *framework.FieldData, methodType string) (*logical.Response, error) { var err error var mConfig *mfa.Config ns, err := namespace.FromContext(ctx) if err != nil { return nil, err } methodID := d.Get("method_id").(string) methodName := d.Get("method_name").(string) b := i.mfaBackend b.mfaLock.Lock() defer b.mfaLock.Unlock() if methodID != "" { mConfig, err = b.MemDBMFAConfigByID(methodID) if err != nil { return nil, err } // If methodID is specified, but we didn't find anything, return a 404 if mConfig == nil { return nil, nil } } // check if an MFA method configuration exists with that method name if methodName != "" { namedMfaConfig, err := b.MemDBMFAConfigByName(ctx, methodName) if err != nil { return nil, err } if namedMfaConfig != nil { if mConfig == nil { mConfig = namedMfaConfig } else { if mConfig.ID != namedMfaConfig.ID { return nil, fmt.Errorf("a login MFA method configuration with the method name %s already exists", methodName) } } } } if mConfig == nil { configID, err := uuid.GenerateUUID() if err != nil { return nil, fmt.Errorf("failed to generate an identifier for MFA config: %v", err) } mConfig = &mfa.Config{ ID: configID, Type: methodType, NamespaceID: ns.ID, } } // Updating the method config name if methodName != "" { mConfig.Name = methodName } mfaNs, err := i.namespacer.NamespaceByID(ctx, mConfig.NamespaceID) if err != nil { return nil, err } // this logic assumes that the config namespace and the current // namespace should be the same. 
Note an ancestor of mfaNs is not allowed // to create/update methodID if ns.ID != mfaNs.ID { return logical.ErrorResponse("request namespace does not match method namespace"), nil } mConfig.Type = methodType usernameRaw, ok := d.GetOk("username_format") if ok { mConfig.UsernameFormat = usernameRaw.(string) } switch methodType { case mfaMethodTypeTOTP: err = parseTOTPConfig(mConfig, d) if err != nil { return logical.ErrorResponse(err.Error()), nil } default: return logical.ErrorResponse(fmt.Sprintf("unrecognized type %q", methodType)), nil } // Store the config err = b.putMFAConfigByID(ctx, mConfig) if err != nil { return nil, err } // Back the config in MemDB err = b.MemDBUpsertMFAConfig(ctx, mConfig) if err != nil { return nil, err } if methodID == "" { return &logical.Response{ Data: map[string]interface{}{ "method_id": mConfig.ID, }, }, nil } else { return nil, nil } } func (i *IdentityStore) handleMFAMethodTOTPUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { return i.handleMFAMethodUpdateCommon(ctx, req, d, mfaMethodTypeTOTP) } func (i *IdentityStore) handleMFAMethodTOTPDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { return i.handleMFAMethodDeleteCommon(ctx, req, d, mfaMethodTypeTOTP) } func (i *IdentityStore) handleMFAMethodDeleteCommon(ctx context.Context, req *logical.Request, d *framework.FieldData, methodType string) (*logical.Response, error) { methodID := d.Get("method_id").(string) if methodID == "" { return logical.ErrorResponse("missing method ID"), nil } return nil, i.mfaBackend.deleteMFAConfigByMethodID(ctx, methodID, methodType, memDBLoginMFAConfigsTable, loginMFAConfigPrefix) } func (i *IdentityStore) handleLoginMFAGenerateUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { return i.handleLoginMFAGenerateCommon(ctx, req, d.Get("method_id").(string), req.EntityID) } func (i *IdentityStore) 
handleLoginMFAAdminGenerateUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { return i.handleLoginMFAGenerateCommon(ctx, req, d.Get("method_id").(string), d.Get("entity_id").(string)) } func (i *IdentityStore) handleLoginMFAGenerateCommon(ctx context.Context, req *logical.Request, methodID, entityID string) (*logical.Response, error) { if methodID == "" { return logical.ErrorResponse("missing method ID"), nil } if entityID == "" { return logical.ErrorResponse("missing entityID"), nil } mConfig, err := i.mfaBackend.MemDBMFAConfigByID(methodID) if err != nil { return nil, err } if mConfig == nil { return logical.ErrorResponse(fmt.Sprintf("configuration for method ID %q does not exist", methodID)), nil } if mConfig.ID == "" { return nil, fmt.Errorf("configuration for method ID %q does not contain an identifier", methodID) } entity, err := i.MemDBEntityByID(entityID, true) if err != nil { return nil, fmt.Errorf("failed to find entity with ID %q: error: %w", entityID, err) } if entity == nil { return logical.ErrorResponse("invalid entity ID"), nil } ns, err := namespace.FromContext(ctx) if err != nil { return logical.ErrorResponse("failed to retrieve the namespace"), nil } if ns.ID != entity.NamespaceID { return logical.ErrorResponse("entity namespace ID does not match the current namespace ID"), nil } entityNS, err := i.namespacer.NamespaceByID(ctx, entity.NamespaceID) if err != nil { return logical.ErrorResponse("entity namespace not found"), nil } configNS, err := i.namespacer.NamespaceByID(ctx, mConfig.NamespaceID) if err != nil { return logical.ErrorResponse("methodID namespace not found"), nil } if configNS.ID != entityNS.ID && !entityNS.HasParent(configNS) { return logical.ErrorResponse(fmt.Sprintf("entity namespace %s outside of the config namespace %s", entityNS.Path, configNS.Path)), nil } switch mConfig.Type { case mfaMethodTypeTOTP: return i.mfaBackend.handleMFAGenerateTOTP(ctx, mConfig, entityID) default: 
return logical.ErrorResponse(fmt.Sprintf("generate not available for MFA type %q", mConfig.Type)), nil } } func (i *IdentityStore) handleLoginMFAAdminDestroyUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { var entity *identity.Entity var err error methodID := d.Get("method_id").(string) if methodID == "" { return logical.ErrorResponse("missing method ID"), nil } entityID := d.Get("entity_id").(string) if entityID == "" { return logical.ErrorResponse("missing entity ID"), nil } entity, err = i.MemDBEntityByID(entityID, true) if err != nil { return nil, fmt.Errorf("failed to find entity with ID %q: error: %w", entityID, err) } if entity == nil { return logical.ErrorResponse("invalid entity ID"), nil } mConfig, err := i.mfaBackend.MemDBMFAConfigByID(methodID) if err != nil { return nil, err } if mConfig == nil { return logical.ErrorResponse(fmt.Sprintf("configuration for method ID %q does not exist", methodID)), nil } if mConfig.ID == "" { return nil, fmt.Errorf("configuration for method ID %q does not contain an identifier", methodID) } if mConfig.Type != mfaMethodTypeTOTP { return nil, fmt.Errorf("method ID does not match TOTP type") } ns, err := namespace.FromContext(ctx) if err != nil { return logical.ErrorResponse("failed to retrieve the namespace"), nil } if ns.ID != entity.NamespaceID { return logical.ErrorResponse("entity namespace ID does not match the current namespace ID"), nil } entityNS, err := i.namespacer.NamespaceByID(ctx, entity.NamespaceID) if err != nil { return logical.ErrorResponse("entity namespace not found"), nil } configNS, err := i.namespacer.NamespaceByID(ctx, mConfig.NamespaceID) if err != nil { return logical.ErrorResponse("methodID namespace not found"), nil } if configNS.ID != entityNS.ID && !entityNS.HasParent(configNS) { return logical.ErrorResponse(fmt.Sprintf("entity namespace %s outside of the current namespace %s", entityNS.Path, ns.Path)), nil } // destroying the secret on the 
entity if entity.MFASecrets != nil { delete(entity.MFASecrets, mConfig.ID) } err = i.upsertEntity(ctx, entity, nil, true) if err != nil { return nil, fmt.Errorf("failed to persist MFA secret in entity, error: %w", err) } return nil, nil } // loadMFAMethodConfigs loads MFA method configs for login MFA func (b *LoginMFABackend) loadMFAMethodConfigs(ctx context.Context, ns *namespace.Namespace) error { b.mfaLogger.Trace("loading login MFA configurations") barrierView, err := b.Core.barrierViewForNamespace(ns.ID) if err != nil { return fmt.Errorf("error getting namespace view, namespaceid %s, error %w", ns.ID, err) } existing, err := barrierView.List(ctx, loginMFAConfigPrefix) if err != nil { return fmt.Errorf("failed to list MFA configurations for namespace path %s and prefix %s: %w", ns.Path, loginMFAConfigPrefix, err) } b.mfaLogger.Trace("methods collected", "num_existing", len(existing)) for _, key := range existing { b.mfaLogger.Trace("loading method", "method", key) // Read the config from storage mConfig, err := b.getMFAConfig(ctx, loginMFAConfigPrefix+key, barrierView) if err != nil { return err } if mConfig == nil { b.mfaLogger.Trace("failed to find the config related to a method", "namespace", ns.Path, "prefix", loginMFAConfigPrefix, "method", key) continue } // Load the config in MemDB err = b.MemDBUpsertMFAConfig(ctx, mConfig) if err != nil { return fmt.Errorf("failed to load configuration ID %s prefix %s in MemDB: %w", mConfig.ID, loginMFAConfigPrefix, err) } } b.mfaLogger.Trace("configurations restored", "namespace", ns.Path, "prefix", loginMFAConfigPrefix) return nil } // loadMFAEnforcementConfigs loads MFA method configs for login MFA func (b *LoginMFABackend) loadMFAEnforcementConfigs(ctx context.Context, ns *namespace.Namespace) ([]*mfa.MFAEnforcementConfig, error) { b.mfaLogger.Trace("loading login MFA enforcement configurations") barrierView, err := b.Core.barrierViewForNamespace(ns.ID) if err != nil { return nil, fmt.Errorf("error getting namespace 
view, namespaceid %s, error %w", ns.ID, err) } existing, err := barrierView.List(ctx, mfaLoginEnforcementPrefix) if err != nil { return nil, fmt.Errorf("failed to list MFA enforcement configurations for namespace %s with prefix %s: %w", ns.Path, mfaLoginEnforcementPrefix, err) } b.mfaLogger.Trace("enforcements configs collected", "num_existing", len(existing)) eConfigs := make([]*mfa.MFAEnforcementConfig, 0) for _, key := range existing { b.mfaLogger.Trace("loading enforcement", "config", key) // Read the config from storage mConfig, err := b.getMFALoginEnforcementConfig(ctx, mfaLoginEnforcementPrefix+key, barrierView) if err != nil { return nil, err } if mConfig == nil { b.mfaLogger.Trace("failed to find an enforcement config", "namespace", ns.Path, "prefix", mfaLoginEnforcementPrefix, "config", key) continue } // Load the config in MemDB err = b.MemDBUpsertMFALoginEnforcementConfig(ctx, mConfig) if err != nil { return nil, fmt.Errorf("failed to load enforcement configuration ID %s with prefix %s in MemDB: %w", mConfig.ID, mfaLoginEnforcementPrefix, err) } eConfigs = append(eConfigs, mConfig) } b.mfaLogger.Trace("enforcement configurations restored", "namespace", ns.Path, "prefix", mfaLoginEnforcementPrefix) return eConfigs, nil } func (b *LoginMFABackend) loginMFAMethodExistenceCheck(eConfig *mfa.MFAEnforcementConfig) error { var aggErr *multierror.Error for _, confID := range eConfig.MFAMethodIDs { config, memErr := b.MemDBMFAConfigByID(confID) if memErr != nil { aggErr = multierror.Append(aggErr, memErr) return aggErr.ErrorOrNil() } if config == nil { aggErr = multierror.Append(aggErr, fmt.Errorf("found an MFA method ID in enforcement config, but failed to find the MFA method config method ID %s", confID)) } } return aggErr.ErrorOrNil() } // sanitizeMFACredsWithLoginEnforcementMethodIDs updates the MFACred map // looping through the matched login enforcement configurations, and // replacing MFA method names with MFA method IDs func (b *LoginMFABackend) 
sanitizeMFACredsWithLoginEnforcementMethodIDs(ctx context.Context, mfaCredsMap logical.MFACreds, mfaMethodIDs []string) (logical.MFACreds, error) { sanitizedMfaCreds := make(logical.MFACreds, 0) var multiError *multierror.Error for _, methodID := range mfaMethodIDs { val, ok := mfaCredsMap[methodID] if ok { sanitizedMfaCreds[methodID] = val continue } mConfig, err := b.MemDBMFAConfigByID(methodID) if err != nil { return nil, err } if mConfig == nil { multiError = multierror.Append(multiError, fmt.Errorf("failed to find MFA config for method ID %s", methodID)) continue } // method name in the MFACredsMap should be the method full name, // i.e., namespacePath+name. This is because, a user in a child // namespace can reference an MFA method ID in a parent namespace configNS, err := NamespaceByID(ctx, mConfig.NamespaceID, b.Core) if err != nil { return nil, err } if configNS != nil { val, ok = mfaCredsMap[configNS.Path+mConfig.Name] if ok { sanitizedMfaCreds[mConfig.ID] = val } else { multiError = multierror.Append(multiError, fmt.Errorf("failed to find MFA credentials associated with an MFA method ID %v, method name %v", methodID, configNS.Path+mConfig.Name)) } } else { multiError = multierror.Append(multiError, fmt.Errorf("failed to find the namespace associated with an MFA method ID %v", mConfig.ID)) } } // we don't need to find every MFA method identifiers in the MFA header // So, don't return errors if that is the case. 
if len(sanitizedMfaCreds) > 0 { return sanitizedMfaCreds, nil } return sanitizedMfaCreds, multiError } func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logical.Request, d *framework.FieldData) (retResp *logical.Response, retErr error) { // mfaReqID is the ID of the login request mfaReqID := d.Get("mfa_request_id").(string) if mfaReqID == "" { return logical.ErrorResponse("missing request ID"), nil } // a map of methodID to passcode mfaPayload := d.Get("mfa_payload") if mfaPayload == nil { return logical.ErrorResponse("missing mfa payload"), nil } var mfaCreds logical.MFACreds err := mapstructure.Decode(mfaPayload, &mfaCreds) if err != nil { return logical.ErrorResponse("invalid mfa payload"), nil } // getting the cached response Auth. We should note that the entry is // removed from the queue, and if any error happens before the validation // and creating a token succeed, we need to push the entry back to the queue. cachedResponseAuth, err := b.Core.PopMFAResponseAuthByID(mfaReqID) if err != nil || cachedResponseAuth == nil { return logical.ErrorResponse("invalid request ID"), nil } defer func() { // Only if retErr is NOT nil, then push back the valid entry if retErr == nil { return } pushErr := b.Core.SaveMFAResponseAuth(cachedResponseAuth) if pushErr != nil { retErr = multierror.Append(retErr, pushErr) } }() ns, err := namespace.FromContext(ctx) if err != nil { return nil, fmt.Errorf("MFA validation failed. Namespace not found. error: %v", err) } if ns.ID != cachedResponseAuth.RequestNSID { return nil, fmt.Errorf("original request was issued in a different namesapce %v, current namespace is %v", cachedResponseAuth.RequestNSPath, ns.Path) } entity, _, err := b.Core.fetchEntityAndDerivedPolicies(ctx, ns, cachedResponseAuth.CachedAuth.EntityID, true) if err != nil || entity == nil { return nil, fmt.Errorf("MFA validation failed. entity not found: %v", err) } // finding the MFAEnforcement config that matches our ns. 
ns could be root as well matchedMfaEnforcementList, err := b.Core.buildMFAEnforcementConfigList(ctx, entity, cachedResponseAuth.RequestPath) if err != nil { return nil, fmt.Errorf("failed to find MFAEnforcement configuration") } if len(matchedMfaEnforcementList) == 0 { return nil, fmt.Errorf("found nil or empty MFAEnforcement configuration") } for _, eConfig := range matchedMfaEnforcementList { err = b.Core.validateLoginMFA(ctx, eConfig, entity, req.Connection.RemoteAddr, mfaCreds) if err != nil { return logical.ErrorResponse(fmt.Sprintf("failed to satisfy enforcement %s. error: %s", eConfig.Name, err.Error())), logical.ErrPermissionDenied } } // MFA validation has passed. Let's generate the token resp, err := b.Core.LoginMFACreateToken(ctx, cachedResponseAuth.RequestPath, cachedResponseAuth.CachedAuth, req.Data) if err != nil { return nil, fmt.Errorf("failed to create a token. error: %v", err) } return resp, nil } func (c *Core) teardownLoginMFA() error { if !c.IsDRSecondary() { // Clear any cached auth response c.mfaResponseAuthQueueLock.Lock() c.mfaResponseAuthQueue = nil c.mfaResponseAuthQueueLock.Unlock() c.loginMFABackend.usedCodes = nil if err := c.loginMFABackend.ResetLoginMFAMemDB(); err != nil { return err } } return nil } // LoginMFACreateToken creates a token after the login MFA is validated. // It also applies the lease quotas on the original login request path. func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAuth *logical.Auth, loginRequestData map[string]interface{}) (*logical.Response, error) { auth := cachedAuth resp := &logical.Response{ Auth: auth, } // Determine the source of the login mountPoint := c.router.MatchingMount(ctx, reqPath) ns, err := namespace.FromContext(ctx) if err != nil { return nil, fmt.Errorf("namespace not found: %w", err) } var role string if reqRole := ctx.Value(logical.CtxKeyRequestRole{}); reqRole != nil { role = reqRole.(string) } // The request successfully authenticated itself. 
Run the quota checks on // the original login request path before creating the token. quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, "as.Request{ Path: reqPath, MountPath: strings.TrimPrefix(mountPoint, ns.Path), Role: role, NamespacePath: ns.Path, }) if quotaErr != nil { c.logger.Error("failed to apply quota", "path", reqPath, "error", quotaErr) return nil, quotaErr } if !quotaResp.Allowed { if c.logger.IsTrace() { c.logger.Trace("request rejected due to lease count quota violation", "request_path", reqPath) } return nil, fmt.Errorf("request path %q: %w", reqPath, quotas.ErrLeaseCountQuotaExceeded) } // note that we don't need to handle the error for the following function right away. // The function takes the response as in input variable and modify it. So, the returned // arguments are resp and err. leaseGenerated, resp, err := c.LoginCreateToken(ctx, ns, reqPath, mountPoint, role, resp) if quotaResp.Access != nil { quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated) if quotaAckErr != nil { err = multierror.Append(err, quotaAckErr) } } return resp, err } func (i *IdentityStore) handleMFALoginEnforcementList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { keys, configInfo, err := i.mfaBackend.mfaLoginEnforcementList(ctx) if err != nil { return nil, err } return logical.ListResponseWithInfo(keys, configInfo), nil } func (i *IdentityStore) handleMFALoginEnforcementRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) ns, err := namespace.FromContext(ctx) if err != nil { return nil, err } respData, err := i.mfaBackend.mfaLoginEnforcementConfigByNameAndNamespace(name, ns.ID) if err != nil { return nil, err } if respData == nil { return nil, nil } // The config is readable only from the same namespace if ns.ID != respData["namespace_id"].(string) { return logical.ErrorResponse("request namespace does not match method 
namespace"), nil } return &logical.Response{ Data: respData, }, nil } func (i *IdentityStore) handleMFALoginEnforcementUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { var err error var eConfig *mfa.MFAEnforcementConfig ns, err := namespace.FromContext(ctx) if err != nil { return nil, err } name := d.Get("name").(string) if name == "" { return logical.ErrorResponse("missing enforcement name"), nil } b := i.mfaBackend b.mfaLock.Lock() defer b.mfaLock.Unlock() eConfig, err = b.MemDBMFALoginEnforcementConfigByNameAndNamespace(name, ns.ID) if err != nil { return nil, err } if eConfig == nil { configID, err := uuid.GenerateUUID() if err != nil { return nil, fmt.Errorf("failed to generate an identifier for MFA login enforcement config: %w", err) } eConfig = &mfa.MFAEnforcementConfig{ Name: name, NamespaceID: ns.ID, ID: configID, } } mfaMethodIds, ok := d.GetOk("mfa_method_ids") if !ok { return logical.ErrorResponse("missing method ids"), nil } for _, mmid := range mfaMethodIds.([]string) { // make sure this method id actually exists config, err := b.mfaConfigReadByMethodID(mmid) if err != nil { return nil, err } if config == nil { return logical.ErrorResponse("one of the provided method ids doesn't exist"), nil } mfaNs, err := i.namespacer.NamespaceByID(ctx, config["namespace_id"].(string)) if err != nil { return logical.ErrorResponse("failed to retrieve config namespace"), nil } if ns.ID != mfaNs.ID && !ns.HasParent(mfaNs) { return logical.ErrorResponse("one of the provided method ids is in an incompatible namespace and can't be used"), nil } } eConfig.MFAMethodIDs = mfaMethodIds.([]string) oneOfLastFour := false authMethodAccessors, ok := d.GetOk("auth_method_accessors") if ok { for _, accessor := range authMethodAccessors.([]string) { found, err := b.validateAuthEntriesForAccessorOrType(ctx, ns, func(entry *MountEntry) bool { return accessor == entry.Accessor }) if err != nil { return nil, err } if !found { return 
logical.ErrorResponse("one of the auth method accessors provided is invalid"), nil } } eConfig.AuthMethodAccessors = authMethodAccessors.([]string) oneOfLastFour = true } authMethodTypes, ok := d.GetOk("auth_method_types") if ok { for _, authType := range authMethodTypes.([]string) { found, err := b.validateAuthEntriesForAccessorOrType(ctx, ns, func(entry *MountEntry) bool { return authType == entry.Type }) if err != nil { return nil, err } if !found { return logical.ErrorResponse("one of the auth method types provided is invalid"), nil } } eConfig.AuthMethodTypes = authMethodTypes.([]string) oneOfLastFour = true } identityGroupIds, ok := d.GetOk("identity_group_ids") if ok { for _, groupId := range identityGroupIds.([]string) { group, err := i.MemDBGroupByID(groupId, true) if err != nil { return nil, err } if group == nil { return logical.ErrorResponse("one of the provided group ids doesn't exist"), nil } } eConfig.IdentityGroupIds = identityGroupIds.([]string) oneOfLastFour = true } identityEntityIds, ok := d.GetOk("identity_entity_ids") if ok { for _, entityId := range identityEntityIds.([]string) { entity, err := i.MemDBEntityByID(entityId, true) if err != nil { return nil, err } if entity == nil { return logical.ErrorResponse("one of the provided entity ids doesn't exist"), nil } } eConfig.IdentityEntityIDs = identityEntityIds.([]string) oneOfLastFour = true } if !oneOfLastFour { return logical.ErrorResponse("One of auth_method_accessors, auth_method_types, identity_group_ids, identity_entity_ids must be specified"), nil } // Store the config err = b.putMFALoginEnforcementConfig(ctx, eConfig) if err != nil { return nil, err } // Back the config in MemDB return nil, b.MemDBUpsertMFALoginEnforcementConfig(ctx, eConfig) } func (i *IdentityStore) handleMFALoginEnforcementDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) ns, err := namespace.FromContext(ctx) if err != nil { return 
nil, err } return nil, i.mfaBackend.deleteMFALoginEnforcementConfigByNameAndNamespace(ctx, name, ns.ID) } func (b *LoginMFABackend) validateAuthEntriesForAccessorOrType(ctx context.Context, ns *namespace.Namespace, validFunc func(entry *MountEntry) bool) (bool, error) { b.Core.authLock.RLock() defer b.Core.authLock.RUnlock() for _, entry := range b.Core.auth.Entries { // only check auth methods in the current namespace if entry.Namespace().ID != ns.ID { continue } cont, err := b.Core.checkReplicatedFiltering(ctx, entry, credentialRoutePrefix) if err != nil { return false, err } if cont { continue } if validFunc(entry) { return true, nil } } return false, nil } func (c *Core) PersistTOTPKey(ctx context.Context, methodID, entityID, key string) error { ks := &totpKey{ Key: key, } val, err := jsonutil.EncodeJSON(ks) if err != nil { return err } if c.barrier.Put(ctx, &logical.StorageEntry{ Key: fmt.Sprintf("%s%s/%s", mfaTOTPKeysPrefix, methodID, entityID), Value: val, }); err != nil { return err } return nil } func (c *Core) fetchTOTPKey(ctx context.Context, methodID, entityID string) (string, error) { entry, err := c.barrier.Get(ctx, fmt.Sprintf("%s%s/%s", mfaTOTPKeysPrefix, methodID, entityID)) if err != nil { return "", err } if entry == nil { return "", nil } ks := &totpKey{} err = jsonutil.DecodeJSON(entry.Value, ks) if err != nil { return "", err } return ks.Key, nil } func (b *MFABackend) handleMFAGenerateTOTP(ctx context.Context, mConfig *mfa.Config, entityID string) (*logical.Response, error) { var err error var totpConfig *mfa.TOTPConfig if b.Core.identityStore == nil { return nil, fmt.Errorf("identity store not set up, cannot service totp mfa requests") } switch mConfig.Config.(type) { case *mfa.Config_TOTPConfig: totpConfig = mConfig.Config.(*mfa.Config_TOTPConfig).TOTPConfig default: return logical.ErrorResponse(fmt.Sprintf("unknown MFA config type %q", mConfig.Type)), nil } b.Core.identityStore.lock.Lock() defer b.Core.identityStore.lock.Unlock() // Read 
the entity after acquiring the lock entity, err := b.Core.identityStore.MemDBEntityByID(entityID, true) if err != nil { return nil, errwrap.Wrapf(fmt.Sprintf("failed to find entity with ID %q: {{err}}", entityID), err) } if entity == nil { return logical.ErrorResponse("invalid entity ID"), nil } if entity.MFASecrets == nil { entity.MFASecrets = make(map[string]*mfa.Secret) } else { _, ok := entity.MFASecrets[mConfig.ID] if ok { resp := &logical.Response{} resp.AddWarning(fmt.Sprintf("Entity already has a secret for MFA method %q", mConfig.Name)) return resp, nil } } keyObject, err := totplib.Generate(totplib.GenerateOpts{ Issuer: totpConfig.Issuer, AccountName: entity.ID, Period: uint(totpConfig.Period), Digits: otplib.Digits(totpConfig.Digits), Algorithm: otplib.Algorithm(totpConfig.Algorithm), SecretSize: uint(totpConfig.KeySize), Rand: b.Core.secureRandomReader, }) if err != nil { return nil, errwrap.Wrapf(fmt.Sprintf("failed to generate TOTP key for method name %q: {{err}}", mConfig.Name), err) } if keyObject == nil { return nil, fmt.Errorf("failed to generate TOTP key for method name %q", mConfig.Name) } totpURL := keyObject.String() totpB64Barcode := "" if totpConfig.QRSize != 0 { barcode, err := keyObject.Image(int(totpConfig.QRSize), int(totpConfig.QRSize)) if err != nil { return nil, errwrap.Wrapf("failed to generate QR code image: {{err}}", err) } var buff bytes.Buffer png.Encode(&buff, barcode) totpB64Barcode = base64.StdEncoding.EncodeToString(buff.Bytes()) } if err := b.Core.PersistTOTPKey(ctx, mConfig.ID, entity.ID, keyObject.Secret()); err != nil { return nil, errwrap.Wrapf("failed to persist totp key: {{err}}", err) } entity.MFASecrets[mConfig.ID] = &mfa.Secret{ MethodName: mConfig.Name, Value: &mfa.Secret_TOTPSecret{ TOTPSecret: &mfa.TOTPSecret{ Issuer: totpConfig.Issuer, AccountName: entity.ID, Period: uint32(totpConfig.Period), Algorithm: int32(totpConfig.Algorithm), Digits: int32(totpConfig.Digits), Skew: uint32(totpConfig.Skew), KeySize: 
uint32(totpConfig.KeySize), }, }, } err = b.Core.identityStore.upsertEntity(ctx, entity, nil, true) if err != nil { return nil, errwrap.Wrapf("failed to persist MFA secret in entity: {{err}}", err) } return &logical.Response{ Data: map[string]interface{}{ "url": totpURL, "barcode": totpB64Barcode, }, }, nil } func (b *LoginMFABackend) mfaConfigReadByMethodID(id string) (map[string]interface{}, error) { mConfig, err := b.MemDBMFAConfigByID(id) if err != nil { return nil, err } if mConfig == nil { return nil, nil } return b.mfaConfigToMap(mConfig) } func (b *LoginMFABackend) mfaMethodList(ctx context.Context, methodType string) ([]string, map[string]interface{}, error) { ns, err := namespace.FromContext(ctx) if err != nil { return nil, nil, err } ws := memdb.NewWatchSet() txn := b.db.Txn(false) var iter memdb.ResultIterator switch { case methodType == "": // get all the configs iter, err = txn.Get(b.methodTable, "id") if err != nil { return nil, nil, fmt.Errorf("failed to fetch iterator for login mfa method configs in memdb: %w", err) } default: // get all the configs for the given type iter, err = txn.Get(b.methodTable, "type", methodType) if err != nil { return nil, nil, fmt.Errorf("failed to fetch iterator for login mfa method configs in memdb: %w", err) } } ws.Add(iter.WatchCh()) var keys []string configInfo := map[string]interface{}{} for { // check for timeouts select { case <-ctx.Done(): return keys, configInfo, nil default: break } raw := iter.Next() if raw == nil { break } config := raw.(*mfa.Config) // return this config if it's in the same ns as the request ns OR it's in a parent ns of the request ns mfaNs, err := b.namespacer.NamespaceByID(ctx, config.NamespaceID) if err != nil { return nil, nil, fmt.Errorf("failed to fetch namespace: %w", err) } // the namespaces have to match, or the config namespace needs to be a parent of the request namespace if !(ns.ID == mfaNs.ID || ns.HasParent(mfaNs)) { continue } keys = append(keys, config.ID) configInfoEntry, 
err := b.mfaConfigToMap(config) if err != nil { return nil, nil, fmt.Errorf("failed to convert config to map: %w", err) } configInfo[config.ID] = configInfoEntry } return keys, configInfo, nil } func (b *LoginMFABackend) mfaLoginEnforcementList(ctx context.Context) ([]string, map[string]interface{}, error) { ns, err := namespace.FromContext(ctx) if err != nil { return nil, nil, err } ws := memdb.NewWatchSet() txn := b.db.Txn(false) // get all the login enforcements in our namespace iter, err := txn.Get(memDBMFALoginEnforcementsTable, "namespace", ns.ID) if err != nil { return nil, nil, fmt.Errorf("failed to fetch iterator for login enforcement configs in memdb: %w", err) } ws.Add(iter.WatchCh()) var keys []string enforcementInfo := map[string]interface{}{} for { // check for timeouts select { case <-ctx.Done(): return keys, enforcementInfo, nil default: break } raw := iter.Next() if raw == nil { break } config := raw.(*mfa.MFAEnforcementConfig) keys = append(keys, config.Name) configInfoEntry, err := b.mfaLoginEnforcementConfigToMap(config) if err != nil { return nil, nil, fmt.Errorf("failed to convert enforcement to map: %w", err) } enforcementInfo[config.Name] = configInfoEntry } return keys, enforcementInfo, nil } func (b *LoginMFABackend) mfaLoginEnforcementConfigByNameAndNamespace(name, namespaceId string) (map[string]interface{}, error) { eConfig, err := b.MemDBMFALoginEnforcementConfigByNameAndNamespace(name, namespaceId) if err != nil { return nil, err } if eConfig == nil { return nil, nil } return b.mfaLoginEnforcementConfigToMap(eConfig) } func (b *LoginMFABackend) mfaLoginEnforcementConfigToMap(eConfig *mfa.MFAEnforcementConfig) (map[string]interface{}, error) { resp := make(map[string]interface{}) resp["name"] = eConfig.Name ns, err := b.namespacer.NamespaceByID(context.Background(), eConfig.NamespaceID) if err != nil { return nil, err } if ns != nil { resp["namespace_path"] = ns.Path } resp["namespace_id"] = eConfig.NamespaceID resp["mfa_method_ids"] = 
append([]string{}, eConfig.MFAMethodIDs...) resp["auth_method_accessors"] = append([]string{}, eConfig.AuthMethodAccessors...) resp["auth_method_types"] = append([]string{}, eConfig.AuthMethodTypes...) resp["identity_group_ids"] = append([]string{}, eConfig.IdentityGroupIds...) resp["identity_entity_ids"] = append([]string{}, eConfig.IdentityEntityIDs...) resp["id"] = eConfig.ID return resp, nil } func (b *MFABackend) mfaConfigToMap(mConfig *mfa.Config) (map[string]interface{}, error) { respData := make(map[string]interface{}) switch mConfig.Config.(type) { case *mfa.Config_TOTPConfig: totpConfig := mConfig.GetTOTPConfig() respData["issuer"] = totpConfig.Issuer respData["period"] = totpConfig.Period respData["digits"] = totpConfig.Digits respData["skew"] = totpConfig.Skew respData["key_size"] = totpConfig.KeySize respData["qr_size"] = totpConfig.QRSize respData["algorithm"] = otplib.Algorithm(totpConfig.Algorithm).String() respData["max_validation_attempts"] = totpConfig.MaxValidationAttempts default: return nil, fmt.Errorf("invalid method type %q was persisted, underlying type: %T", mConfig.Type, mConfig.Config) } respData["type"] = mConfig.Type respData["id"] = mConfig.ID respData["name"] = mConfig.Name respData["namespace_id"] = mConfig.NamespaceID ns, err := b.namespacer.NamespaceByID(context.Background(), mConfig.NamespaceID) if err != nil { return nil, err } if ns != nil { respData["namespace_path"] = ns.Path } return respData, nil } func parseTOTPConfig(mConfig *mfa.Config, d *framework.FieldData) error { if mConfig == nil { return fmt.Errorf("config is nil") } if d == nil { return fmt.Errorf("field data is nil") } algorithm := d.Get("algorithm").(string) var keyAlgorithm otplib.Algorithm switch algorithm { case "SHA1": keyAlgorithm = otplib.AlgorithmSHA1 case "SHA256": keyAlgorithm = otplib.AlgorithmSHA256 case "SHA512": keyAlgorithm = otplib.AlgorithmSHA512 default: return fmt.Errorf("unrecognized algorithm") } digits := d.Get("digits").(int) var keyDigits 
otplib.Digits switch digits { case 6: keyDigits = otplib.DigitsSix case 8: keyDigits = otplib.DigitsEight default: return fmt.Errorf("digits can only be 6 or 8") } period := d.Get("period").(int) if period <= 0 { return fmt.Errorf("period must be greater than zero") } skew := d.Get("skew").(int) switch skew { case 0: case 1: default: return fmt.Errorf("skew must be 0 or 1") } keySize := d.Get("key_size").(int) if keySize <= 0 { return fmt.Errorf("key_size must be greater than zero") } issuer := d.Get("issuer").(string) if issuer == "" { return fmt.Errorf("issuer must be set") } maxValidationAttempt := d.Get("max_validation_attempts").(int) if maxValidationAttempt < 0 { return fmt.Errorf("max_validation_attempts must be greater than zero") } if maxValidationAttempt == 0 { maxValidationAttempt = defaultMaxTOTPValidateAttempts } config := &mfa.TOTPConfig{ Issuer: issuer, Period: uint32(period), Algorithm: int32(keyAlgorithm), Digits: int32(keyDigits), Skew: uint32(skew), KeySize: uint32(keySize), QRSize: int32(d.Get("qr_size").(int)), MaxValidationAttempts: uint32(maxValidationAttempt), } mConfig.Config = &mfa.Config_TOTPConfig{ TOTPConfig: config, } return nil } func (c *Core) validateLoginMFA(ctx context.Context, eConfig *mfa.MFAEnforcementConfig, entity *identity.Entity, requestConnRemoteAddr string, mfaCredsMap logical.MFACreds) error { sanitizedMfaCreds, err := c.loginMFABackend.sanitizeMFACredsWithLoginEnforcementMethodIDs(ctx, mfaCredsMap, eConfig.MFAMethodIDs) if err != nil { return fmt.Errorf("failed to sanitize MFA creds, %w", err) } if len(sanitizedMfaCreds) == 0 && len(eConfig.MFAMethodIDs) > 0 { return fmt.Errorf("login MFA validation failed for methodID: %v", eConfig.MFAMethodIDs) } var retErr error for _, methodID := range eConfig.MFAMethodIDs { // as configID is the same as methodID, and methodID is unique, we can // use it to retrieve the MFACreds mfaCreds, ok := sanitizedMfaCreds[methodID] if !ok || mfaCreds == nil { continue } err := 
c.validateLoginMFAInternal(ctx, methodID, entity, requestConnRemoteAddr, mfaCreds) if err != nil { retErr = multierror.Append(retErr, err) continue } return nil } return multierror.Append(retErr, fmt.Errorf("login MFA validation failed for methodID: %v", eConfig.MFAMethodIDs)) } func (c *Core) validateLoginMFAInternal(ctx context.Context, methodID string, entity *identity.Entity, reqConnectionRemoteAddress string, mfaCreds []string) (retErr error) { if entity == nil { return fmt.Errorf("entity is nil") } // Get the configuration for the MFA method set in system backend mConfig, err := c.loginMFABackend.MemDBMFAConfigByID(methodID) if err != nil { return fmt.Errorf("failed to read MFA configuration") } if mConfig == nil { return fmt.Errorf("MFA method configuration not present") } mfaFactors, err := parseMfaFactors(mfaCreds) if err != nil { return fmt.Errorf("failed to parse MFA factor, %w", err) } switch mConfig.Type { case mfaMethodTypeTOTP: // Get the MFA secret data required to validate the supplied credentials if entity.MFASecrets == nil { return fmt.Errorf("MFA secret for method ID %q not present in entity %q", mConfig.ID, entity.ID) } entityMFASecret := entity.MFASecrets[mConfig.ID] if entityMFASecret == nil { return fmt.Errorf("MFA secret for method name %q not present in entity %q", mConfig.Name, entity.ID) } return c.validateTOTP(ctx, mfaFactors, entityMFASecret, mConfig.ID, entity.ID, c.loginMFABackend.usedCodes, mConfig.GetTOTPConfig().MaxValidationAttempts) default: return fmt.Errorf("unrecognized MFA type %q", mConfig.Type) } } func (c *Core) buildMFAEnforcementConfigList(ctx context.Context, entity *identity.Entity, reqPath string) ([]*mfa.MFAEnforcementConfig, error) { ns, err := namespace.FromContext(ctx) if err != nil { return nil, fmt.Errorf("failed to get namespace from context. 
%s, %v", "error", err) } eConfigIter, err := c.loginMFABackend.MemDBMFALoginEnforcementConfigIterator() if err != nil { return nil, err } me := c.router.MatchingMountEntry(ctx, reqPath) if me == nil { return nil, fmt.Errorf("failed to find matching mount entry for path %v", reqPath) } var matchedMfaEnforcementConfig []*mfa.MFAEnforcementConfig // finding the MFAEnforcement config that matches our ns. ns could be root as well ECONFIG_LOOP: for eConfigRaw := eConfigIter.Next(); eConfigRaw != nil; eConfigRaw = eConfigIter.Next() { eConfig := eConfigRaw.(*mfa.MFAEnforcementConfig) // check if this config's ns applies to current req, // i.e. is it the req's ns or an ancestor of req's ns? eConfigNS, err := c.NamespaceByID(ctx, eConfig.NamespaceID) if err != nil { return nil, fmt.Errorf("failed to find the MFAEnforcementConfig namespace") } if eConfig == nil || eConfigNS == nil || (eConfigNS.ID != ns.ID && !ns.HasParent(eConfigNS)) { continue } // if entity is nil, an MFAEnforcementConfig could still be configured // having mount type/accessor if entity != nil { if entity.NamespaceID != ns.ID { return nil, fmt.Errorf("entity namespace ID is different than the current ns ID") } // Check if entityID is in the MFAEnforcement config if strutil.StrListContains(eConfig.IdentityEntityIDs, entity.ID) { matchedMfaEnforcementConfig = append(matchedMfaEnforcementConfig, eConfig) continue } // Retrieve entity groups directGroups, inheritedGroups, err := c.identityStore.groupsByEntityID(entity.ID) if err != nil { return nil, fmt.Errorf("error on retrieving groups by entityID in MFA") } for _, g := range directGroups { if strutil.StrListContains(eConfig.IdentityGroupIds, g.ID) { matchedMfaEnforcementConfig = append(matchedMfaEnforcementConfig, eConfig) continue ECONFIG_LOOP } } for _, g := range inheritedGroups { if strutil.StrListContains(eConfig.IdentityGroupIds, g.ID) { matchedMfaEnforcementConfig = append(matchedMfaEnforcementConfig, eConfig) continue ECONFIG_LOOP } } } for _, acc 
// MFAFactor holds the credential parsed from the MFA payload of a request:
// the passcode supplied for a single MFA method.
type MFAFactor struct {
	passcode string
}

// parseMfaFactors extracts the passcode from the credentials supplied for a
// single MFA method.
//
// Each entry of creds is either empty (ignored), of the form
// "passcode=<value>", or a bare passcode. Any other entry containing "=" is
// rejected, as is an empty value after "passcode=" or more than one passcode
// overall. When no passcode is present, (nil, nil) is returned.
func parseMfaFactors(creds []string) (*MFAFactor, error) {
	var passcode string

	for _, cred := range creds {
		if cred == "" {
			// for the case of push notification
			continue
		}

		// Work out the candidate passcode carried by this entry. A bare,
		// non-empty entry without "=" is taken as the passcode itself.
		candidate := cred
		if strings.HasPrefix(cred, "passcode=") {
			candidate = strings.TrimPrefix(cred, "passcode=")
		} else if strings.Contains(cred, "=") {
			return nil, fmt.Errorf("found an invalid MFA cred: %v", cred)
		}

		if passcode != "" {
			return nil, fmt.Errorf("found multiple passcodes for the same MFA method")
		}
		// Only the "passcode=" form can produce an empty candidate.
		if candidate == "" {
			return nil, fmt.Errorf("invalid passcode")
		}
		passcode = candidate
	}

	if passcode == "" {
		return nil, nil
	}
	return &MFAFactor{passcode: passcode}, nil
}
usedCodes *cache.Cache, maximumValidationAttempts uint32) error { if mfaFactors == nil || mfaFactors.passcode == "" { return fmt.Errorf("MFA credentials not supplied") } passcode := mfaFactors.passcode totpSecret := entityMethodSecret.GetTOTPSecret() if totpSecret == nil { return fmt.Errorf("entity does not contain the TOTP secret") } usedName := fmt.Sprintf("%s_%s", configID, passcode) _, ok := usedCodes.Get(usedName) if ok { return fmt.Errorf("code already used; new code is available in %v seconds", totpSecret.Period) } // The duration in which a passcode is stored in cache to enforce // rate limit on failed totp passcode validation passcodeTTL := time.Duration(int64(time.Second) * int64(totpSecret.Period)) // Enforcing rate limit per MethodID per EntityID rateLimitID := fmt.Sprintf("%s_%s", configID, entityID) numAttempts, _ := usedCodes.Get(rateLimitID) if numAttempts == nil { usedCodes.Set(rateLimitID, uint32(1), passcodeTTL) } else { num, ok := numAttempts.(uint32) if !ok { return fmt.Errorf("invalid counter type returned in TOTP usedCode cache") } if num == maximumValidationAttempts { return fmt.Errorf("maximum TOTP validation attempts %d exceeded the allowed attempts %d. 
Please try again in %v seconds", num+1, maximumValidationAttempts, passcodeTTL) } err := usedCodes.Increment(rateLimitID, 1) if err != nil { return fmt.Errorf("failed to increment the TOTP code counter") } } key, err := c.fetchTOTPKey(ctx, configID, entityID) if err != nil { return errwrap.Wrapf("error fetching TOTP key: {{err}}", err) } if key == "" { return fmt.Errorf("empty key for entity's TOTP secret") } validateOpts := totplib.ValidateOpts{ Period: uint(totpSecret.Period), Skew: uint(totpSecret.Skew), Digits: otplib.Digits(int(totpSecret.Digits)), Algorithm: otplib.Algorithm(int(totpSecret.Algorithm)), } valid, err := totplib.ValidateCustom(passcode, key, time.Now(), validateOpts) if err != nil && err != otplib.ErrValidateInputInvalidLength { return errwrap.Wrapf("failed to validate TOTP passcode: {{err}}", err) } if !valid { return fmt.Errorf("failed to validate TOTP passcode") } // Take the key skew, add two for behind and in front, and multiply that by // the period to cover the full possibility of the validity of the key validityPeriod := time.Duration(int64(time.Second) * int64(totpSecret.Period) * int64(2+totpSecret.Skew)) // Adding the used code to the cache err = usedCodes.Add(usedName, nil, validityPeriod) if err != nil { return fmt.Errorf("error adding code to used cache: %w", err) } // deleting the cache entry after a successful MFA validation usedCodes.Delete(rateLimitID) return nil } func loginMFAConfigTableSchema() *memdb.TableSchema { return &memdb.TableSchema{ Name: memDBLoginMFAConfigsTable, Indexes: map[string]*memdb.IndexSchema{ "id": { Name: "id", Unique: true, Indexer: &memdb.StringFieldIndex{ Field: "ID", }, }, "namespace_id": { Name: "namespace_id", Unique: false, Indexer: &memdb.StringFieldIndex{ Field: "NamespaceID", }, }, "type": { Name: "type", Unique: false, Indexer: &memdb.StringFieldIndex{ Field: "Type", }, }, "name": { Name: "name", Unique: true, AllowMissing: true, Indexer: &memdb.CompoundIndex{ Indexes: []memdb.Indexer{ 
&memdb.StringFieldIndex{ Field: "NamespaceID", }, &memdb.StringFieldIndex{ Field: "Name", }, }, }, }, }, } } // turns out every memdb table schema must have an id index func loginEnforcementTableSchema() *memdb.TableSchema { return &memdb.TableSchema{ Name: memDBMFALoginEnforcementsTable, Indexes: map[string]*memdb.IndexSchema{ "id": { Name: "id", Unique: true, Indexer: &memdb.StringFieldIndex{ Field: "ID", }, }, "namespace": { Name: "namespace", Unique: false, Indexer: &memdb.StringFieldIndex{ Field: "NamespaceID", }, }, "nameAndNamespace": { Name: "nameAndNamespace", Unique: true, Indexer: &memdb.CompoundIndex{ Indexes: []memdb.Indexer{ &memdb.StringFieldIndex{ Field: "Name", }, &memdb.StringFieldIndex{ Field: "NamespaceID", }, }, }, }, }, } } func (b *MFABackend) MemDBUpsertMFAConfig(ctx context.Context, mConfig *mfa.Config) error { txn := b.db.Txn(true) defer txn.Abort() err := b.MemDBUpsertMFAConfigInTxn(txn, mConfig) if err != nil { return err } txn.Commit() return nil } func (b *MFABackend) MemDBUpsertMFAConfigInTxn(txn *memdb.Txn, mConfig *mfa.Config) error { if txn == nil { return fmt.Errorf("nil txn") } if mConfig == nil { return fmt.Errorf("config is nil") } mConfigRaw, err := txn.First(b.methodTable, "id", mConfig.ID) if err != nil { return errwrap.Wrapf("failed to lookup MFA config from MemDB using id: {{err}}", err) } if mConfigRaw != nil { err = txn.Delete(b.methodTable, mConfigRaw) if err != nil { return errwrap.Wrapf("failed to delete MFA config from MemDB: {{err}}", err) } } if err := txn.Insert(b.methodTable, mConfig); err != nil { return errwrap.Wrapf("failed to update MFA config into MemDB: {{err}}", err) } return nil } func (b *LoginMFABackend) MemDBUpsertMFALoginEnforcementConfig(ctx context.Context, eConfig *mfa.MFAEnforcementConfig) error { if eConfig == nil { return fmt.Errorf("config is nil") } txn := b.db.Txn(true) defer txn.Abort() eConfigRaw, err := txn.First(memDBMFALoginEnforcementsTable, "nameAndNamespace", eConfig.Name, 
eConfig.NamespaceID) if err != nil { return fmt.Errorf("failed to lookup MFA login enforcement config from MemDB using name: %w", err) } if eConfigRaw != nil { err = txn.Delete(memDBMFALoginEnforcementsTable, eConfigRaw) if err != nil { return fmt.Errorf("failed to delete MFA login enforcement config from MemDB: %w", err) } } if err := txn.Insert(memDBMFALoginEnforcementsTable, eConfig); err != nil { return fmt.Errorf("failed to update MFA login enforcement config in MemDB: %w", err) } txn.Commit() return nil } func (b *LoginMFABackend) MemDBMFAConfigByIDInTxn(txn *memdb.Txn, mConfigID string) (*mfa.Config, error) { if mConfigID == "" { return nil, fmt.Errorf("missing config id") } if txn == nil { return nil, fmt.Errorf("txn is nil") } mConfigRaw, err := txn.First(b.methodTable, "id", mConfigID) if err != nil { return nil, errwrap.Wrapf("failed to fetch MFA config from memdb using id: {{err}}", err) } if mConfigRaw == nil { return nil, nil } mConfig, ok := mConfigRaw.(*mfa.Config) if !ok { return nil, fmt.Errorf("failed to declare the type of fetched MFA config") } return mConfig.Clone() } func (b *LoginMFABackend) MemDBMFAConfigByID(mConfigID string) (*mfa.Config, error) { if mConfigID == "" { return nil, fmt.Errorf("missing config id") } txn := b.db.Txn(false) return b.MemDBMFAConfigByIDInTxn(txn, mConfigID) } func (b *LoginMFABackend) MemDBMFAConfigByNameInTxn(ctx context.Context, txn *memdb.Txn, mConfigName string) (*mfa.Config, error) { if mConfigName == "" { return nil, fmt.Errorf("missing config name") } if txn == nil { return nil, fmt.Errorf("txn is nil") } ns, err := namespace.FromContext(ctx) if err != nil { return nil, err } mConfigRaw, err := txn.First(b.methodTable, "name", ns.ID, mConfigName) if err != nil { return nil, fmt.Errorf("failed to fetch MFA config from memdb using name: %w", err) } if mConfigRaw == nil { return nil, nil } mConfig, ok := mConfigRaw.(*mfa.Config) if !ok { return nil, fmt.Errorf("failed to declare the type of fetched MFA 
config") } return mConfig.Clone() } func (b *LoginMFABackend) MemDBMFAConfigByName(ctx context.Context, name string) (*mfa.Config, error) { if name == "" { return nil, fmt.Errorf("missing config name") } txn := b.db.Txn(false) return b.MemDBMFAConfigByNameInTxn(ctx, txn, name) } func (b *LoginMFABackend) MemDBMFALoginEnforcementConfigByNameAndNamespace(name, namespaceId string) (*mfa.MFAEnforcementConfig, error) { if name == "" { return nil, fmt.Errorf("missing config name") } txn := b.db.Txn(false) defer txn.Abort() eConfigRaw, err := txn.First(memDBMFALoginEnforcementsTable, "nameAndNamespace", name, namespaceId) if err != nil { return nil, fmt.Errorf("failed to fetch MFA login enforcement config from memdb using name: %w", err) } if eConfigRaw == nil { return nil, nil } eConfig, ok := eConfigRaw.(*mfa.MFAEnforcementConfig) if !ok { return nil, fmt.Errorf("invalid type for MFA login enforcement config in memdb") } return eConfig.Clone() } func (b *LoginMFABackend) MemDBMFALoginEnforcementConfigIterator() (memdb.ResultIterator, error) { txn := b.db.Txn(false) defer txn.Abort() // List all the MFAEnforcementConfigs it, err := txn.Get(memDBMFALoginEnforcementsTable, "id") if err != nil { return nil, fmt.Errorf("failed to get an iterator over the MFAEnforcementConfig table: %w", err) } return it, nil } func (b *LoginMFABackend) deleteMFALoginEnforcementConfigByNameAndNamespace(ctx context.Context, name, namespaceId string) error { var err error if name == "" { return fmt.Errorf("missing config name") } b.mfaLock.Lock() defer b.mfaLock.Unlock() // delete the config from storage eConfig, err := b.MemDBMFALoginEnforcementConfigByNameAndNamespace(name, namespaceId) if err != nil { return err } if eConfig == nil { return nil } entryIndex := mfaLoginEnforcementPrefix + eConfig.ID barrierView, err := b.Core.barrierViewForNamespace(eConfig.NamespaceID) if err != nil { return err } err = barrierView.Delete(ctx, entryIndex) if err != nil { return err } // create a memdb 
transaction to delete config txn := b.db.Txn(true) defer txn.Abort() err = txn.Delete(memDBMFALoginEnforcementsTable, eConfig) if err != nil { return fmt.Errorf("failed to delete MFA login enforcement config from memdb: %w", err) } txn.Commit() return nil } func (b *LoginMFABackend) MemDBDeleteMFALoginEnforcementConfigByNameAndNamespace(name, namespaceId, tableName string) error { if name == "" || namespaceId == "" { return nil } txn := b.db.Txn(true) defer txn.Abort() eConfig, err := b.MemDBMFALoginEnforcementConfigByNameAndNamespace(name, namespaceId) if err != nil { return err } if eConfig == nil { return nil } err = txn.Delete(memDBMFALoginEnforcementsTable, eConfig) if err != nil { return err } txn.Commit() return nil } func (b *LoginMFABackend) deleteMFAConfigByMethodID(ctx context.Context, configID, methodType, tableName, prefix string) error { var err error if configID == "" { return fmt.Errorf("missing config id") } b.mfaLock.Lock() defer b.mfaLock.Unlock() eConfigIter, err := b.MemDBMFALoginEnforcementConfigIterator() if err != nil { return err } for eConfigRaw := eConfigIter.Next(); eConfigRaw != nil; eConfigRaw = eConfigIter.Next() { eConfig := eConfigRaw.(*mfa.MFAEnforcementConfig) if strutil.StrListContains(eConfig.MFAMethodIDs, configID) { return fmt.Errorf("methodID is still used by an enforcement configuration with ID: %s", eConfig.ID) } } // Delete the config from storage entryIndex := prefix + configID err = b.Core.systemBarrierView.Delete(ctx, entryIndex) if err != nil { return err } // Create a MemDB transaction to delete config txn := b.db.Txn(true) defer txn.Abort() mConfig, err := b.MemDBMFAConfigByIDInTxn(txn, configID) if err != nil { return err } if mConfig == nil { return nil } if mConfig.Type != methodType { return fmt.Errorf("method type does not match the MFA config type") } mfaNs, err := b.Core.NamespaceByID(ctx, mConfig.NamespaceID) if err != nil { return err } ns, err := namespace.FromContext(ctx) if err != nil { return err } // 
this logic assumes that the config namespace and the current // namespace should be the same. Note an ancestor of mfaNs is not allowed // to delete methodID if ns.ID != mfaNs.ID { return fmt.Errorf("request namespace does not match method namespace") } if mConfig.Type == "totp" && mConfig.ID != "" { // This is best effort; if they end up hanging around it's okay, they're encrypted anyways if err := logical.ClearView(ctx, NewBarrierView(b.Core.barrier, fmt.Sprintf("%s%s", mfaTOTPKeysPrefix, mConfig.ID))); err != nil { b.mfaLogger.Warn("unable to clear TOTP keys", "method", mConfig.Name, "error", err) } } // Delete the config from MemDB err = b.MemDBDeleteMFAConfigByIDInTxn(txn, configID) if err != nil { return err } txn.Commit() return nil } func (b *LoginMFABackend) MemDBDeleteMFAConfigByID(methodId, tableName string) error { if methodId == "" { return nil } txn := b.db.Txn(true) defer txn.Abort() err := b.MemDBDeleteMFAConfigByIDInTxn(txn, methodId) if err != nil { return err } txn.Commit() return nil } func (b *LoginMFABackend) MemDBDeleteMFAConfigByIDInTxn(txn *memdb.Txn, configID string) error { if configID == "" { return nil } if txn == nil { return fmt.Errorf("txn is nil") } mConfig, err := b.MemDBMFAConfigByIDInTxn(txn, configID) if err != nil { return err } if mConfig == nil { return nil } err = txn.Delete(b.methodTable, mConfig) if err != nil { return errwrap.Wrapf("failed to delete MFA config from memdb: {{err}}", err) } return nil } func (b *LoginMFABackend) putMFAConfigByID(ctx context.Context, mConfig *mfa.Config) error { barrierView, err := b.Core.barrierViewForNamespace(mConfig.NamespaceID) if err != nil { return err } return b.putMFAConfigCommon(ctx, mConfig, loginMFAConfigPrefix, mConfig.ID, barrierView) } func (b *MFABackend) putMFAConfigCommon(ctx context.Context, mConfig *mfa.Config, prefix, suffix string, barrierView *BarrierView) error { entryIndex := prefix + suffix marshaledEntry, err := proto.Marshal(mConfig) if err != nil { return err } 
return barrierView.Put(ctx, &logical.StorageEntry{ Key: entryIndex, Value: marshaledEntry, }) } func (b *MFABackend) getMFAConfig(ctx context.Context, path string, barrierView *BarrierView) (*mfa.Config, error) { entry, err := barrierView.Get(ctx, path) if err != nil { return nil, err } if entry == nil { return nil, nil } var mConfig mfa.Config err = proto.Unmarshal(entry.Value, &mConfig) if err != nil { return nil, err } return &mConfig, nil } func (b *LoginMFABackend) getMFALoginEnforcementConfig(ctx context.Context, path string, barrierView *BarrierView) (*mfa.MFAEnforcementConfig, error) { entry, err := barrierView.Get(ctx, path) if err != nil { return nil, err } if entry == nil { return nil, nil } var mConfig mfa.MFAEnforcementConfig err = proto.Unmarshal(entry.Value, &mConfig) if err != nil { return nil, err } return &mConfig, nil } func (b *LoginMFABackend) putMFALoginEnforcementConfig(ctx context.Context, eConfig *mfa.MFAEnforcementConfig) error { entryIndex := mfaLoginEnforcementPrefix + eConfig.ID marshaledEntry, err := proto.Marshal(eConfig) if err != nil { return err } barrierView, err := b.Core.barrierViewForNamespace(eConfig.NamespaceID) if err != nil { return err } return barrierView.Put(ctx, &logical.StorageEntry{ Key: entryIndex, Value: marshaledEntry, }) } var mfaHelp = map[string][2]string{ "methods-list": { "Lists all the available MFA methods by their name.", "", }, "totp-generate": { `Generates a TOTP secret for the given method name on the entity of the calling token.`, `This endpoint generates an MFA secret based on the configuration tied to the method name and stores it in the entity of the token making this request.`, }, "totp-admin-generate": { `Generates a TOTP secret for the given method name on the given entity.`, `This endpoint generates an MFA secret based on the configuration tied to the method name and stores it in the entity corresponding to the given entity identifier. 
// mfaHelp holds the help text for the MFA endpoints, keyed by a help
// identifier. Each value is a pair of strings; judging by the entries, the
// first is a one-line synopsis and the second a longer description, left
// empty when there is nothing to add.
var mfaHelp = map[string][2]string{
	// Listing of configured MFA methods.
	"methods-list": {
		"Lists all the available MFA methods by their name.",
		"",
	},
	// Self-service TOTP secret generation for the calling token's entity.
	"totp-generate": {
		`Generates a TOTP secret for the given method name on the entity of the calling token.`,
		`This endpoint generates an MFA secret based on the configuration tied to the method name and stores it in the entity of the token making this request.`,
	},
	// Administrative TOTP secret generation for an arbitrary entity.
	"totp-admin-generate": {
		`Generates a TOTP secret for the given method name on the given entity.`,
		`This endpoint generates an MFA secret based on the configuration tied to the method name and stores it in the entity corresponding to the given entity identifier. This endpoint is used to administratively generate TOTP secrets on entities.`,
	},
	// Administrative removal of an entity's secret for a method.
	"totp-admin-destroy": {
		`Deletes the TOTP secret for the given method name on the given entity.`,
		`This endpoint removes the secret belonging to method name from the entity regardless of the secret type.`,
	},
	// Creation or update of a TOTP MFA method.
	"totp-method": {
		"Defines or updates a TOTP MFA method.",
		"",
	},
}