1
0

CLI: Tune plugin version for auth/secret mounts (#17277)

* Add -plugin-version flag to vault auth/secrets tune
* CLI tests for auth/secrets tune
* CLI test for plugin register
* Plugin catalog listing bug where plugins of different type with the same name could be double counted
* Use constant for -plugin-version flag name
This commit is contained in:
Tom Proctor 2022-09-22 20:55:46 +01:00 committed by GitHub
parent 6fc6bb1bb5
commit 21d13633d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 207 additions and 54 deletions

View File

@ -247,7 +247,6 @@ type MountInput struct {
SealWrap bool `json:"seal_wrap" mapstructure:"seal_wrap"`
ExternalEntropyAccess bool `json:"external_entropy_access" mapstructure:"external_entropy_access"`
Options map[string]string `json:"options"`
PluginVersion string `json:"plugin_version,omitempty"`
// Deprecated: Newer server responses should be returning this information in the
// Type field (json: "type") instead.
@ -267,6 +266,7 @@ type MountConfigInput struct {
AllowedResponseHeaders []string `json:"allowed_response_headers,omitempty" mapstructure:"allowed_response_headers"`
TokenType string `json:"token_type,omitempty" mapstructure:"token_type"`
AllowedManagedKeys []string `json:"allowed_managed_keys,omitempty" mapstructure:"allowed_managed_keys"`
PluginVersion string `json:"plugin_version,omitempty"`
// Deprecated: This field will always be blank for newer server responses.
PluginName string `json:"plugin_name,omitempty" mapstructure:"plugin_name"`

View File

@ -201,7 +201,7 @@ func (c *AuthEnableCommand) Flags() *FlagSets {
})
f.StringVar(&StringVar{
Name: "plugin-version",
Name: flagNamePluginVersion,
Target: &c.flagPluginVersion,
Default: "",
Usage: "Select the semantic version of the plugin to enable.",
@ -270,7 +270,6 @@ func (c *AuthEnableCommand) Run(args []string) int {
authOpts := &api.EnableAuthOptions{
Type: authType,
PluginVersion: c.flagPluginVersion,
Description: c.flagDescription,
Local: c.flagLocal,
SealWrap: c.flagSealWrap,
@ -307,6 +306,10 @@ func (c *AuthEnableCommand) Run(args []string) int {
if fl.Name == flagNameTokenType {
authOpts.Config.TokenType = c.flagTokenType
}
if fl.Name == flagNamePluginVersion {
authOpts.Config.PluginVersion = c.flagPluginVersion
}
})
if err := client.Sys().EnableAuthWithOptions(authPath, authOpts); err != nil {

View File

@ -31,6 +31,7 @@ type AuthTuneCommand struct {
flagOptions map[string]string
flagTokenType string
flagVersion int
flagPluginVersion string
}
func (c *AuthTuneCommand) Synopsis() string {
@ -144,6 +145,14 @@ func (c *AuthTuneCommand) Flags() *FlagSets {
Usage: "Select the version of the auth method to run. Not supported by all auth methods.",
})
f.StringVar(&StringVar{
Name: flagNamePluginVersion,
Target: &c.flagPluginVersion,
Default: "",
Usage: "Select the semantic version of the plugin to run. The new version must be registered in " +
"the plugin catalog, and will not start running until the plugin is reloaded.",
})
return set
}
@ -221,6 +230,10 @@ func (c *AuthTuneCommand) Run(args []string) int {
if fl.Name == flagNameTokenType {
mountConfigInput.TokenType = c.flagTokenType
}
if fl.Name == flagNamePluginVersion {
mountConfigInput.PluginVersion = c.flagPluginVersion
}
})
// Append /auth (since that's where auths live) and a trailing slash to

View File

@ -6,6 +6,8 @@ import (
"github.com/go-test/deep"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
)
@ -74,7 +76,10 @@ func TestAuthTuneCommand_Run(t *testing.T) {
t.Run("integration", func(t *testing.T) {
t.Run("flags_all", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
pluginDir, cleanup := vault.MakeTestPluginDir(t)
defer cleanup(t)
client, _, closer := testVaultServerPluginDir(t, pluginDir)
defer closer()
ui, cmd := testAuthTuneCommand(t)
@ -87,6 +92,21 @@ func TestAuthTuneCommand_Run(t *testing.T) {
t.Fatal(err)
}
auths, err := client.Sys().ListAuth()
if err != nil {
t.Fatal(err)
}
mountInfo, ok := auths["my-auth/"]
if !ok {
t.Fatalf("expected mount to exist: %#v", auths)
}
if exp := ""; mountInfo.PluginVersion != exp {
t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp)
}
_, _, version := testPluginCreateAndRegisterVersioned(t, client, pluginDir, "userpass", consts.PluginTypeCredential)
code := cmd.Run([]string{
"-description", "new description",
"-default-lease-ttl", "30m",
@ -97,6 +117,7 @@ func TestAuthTuneCommand_Run(t *testing.T) {
"-passthrough-request-headers", "www-authentication",
"-allowed-response-headers", "authorization,www-authentication",
"-listing-visibility", "unauth",
"-plugin-version", version,
"my-auth/",
})
if exp := 0; code != exp {
@ -109,12 +130,12 @@ func TestAuthTuneCommand_Run(t *testing.T) {
t.Errorf("expected %q to contain %q", combined, expected)
}
auths, err := client.Sys().ListAuth()
auths, err = client.Sys().ListAuth()
if err != nil {
t.Fatal(err)
}
mountInfo, ok := auths["my-auth/"]
mountInfo, ok = auths["my-auth/"]
if !ok {
t.Fatalf("expected auth to exist")
}
@ -124,6 +145,9 @@ func TestAuthTuneCommand_Run(t *testing.T) {
if exp := "userpass"; mountInfo.Type != exp {
t.Errorf("expected %q to be %q", mountInfo.Type, exp)
}
if exp := version; mountInfo.PluginVersion != exp {
t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp)
}
if exp := 1800; mountInfo.Config.DefaultLeaseTTL != exp {
t.Errorf("expected %d to be %d", mountInfo.Config.DefaultLeaseTTL, exp)
}

View File

@ -124,6 +124,8 @@ const (
flagNameTokenType = "token-type"
// flagNameAllowedManagedKeys is the flag name used for auth/secrets enable
flagNameAllowedManagedKeys = "allowed-managed-keys"
// flagNamePluginVersion selects what version of a plugin should be used.
flagNamePluginVersion = "plugin-version"
)
var (

View File

@ -1,6 +1,8 @@
package command
import (
"reflect"
"sort"
"strings"
"testing"
@ -124,6 +126,75 @@ func TestPluginRegisterCommand_Run(t *testing.T) {
}
})
t.Run("integration with version", func(t *testing.T) {
t.Parallel()
pluginDir, cleanup := vault.MakeTestPluginDir(t)
defer cleanup(t)
client, _, closer := testVaultServerPluginDir(t, pluginDir)
defer closer()
const pluginName = "my-plugin"
versions := []string{"v1.0.0", "v2.0.1"}
_, sha256Sum := testPluginCreate(t, pluginDir, pluginName)
types := []consts.PluginType{consts.PluginTypeCredential, consts.PluginTypeDatabase, consts.PluginTypeSecrets}
for _, typ := range types {
for _, version := range versions {
ui, cmd := testPluginRegisterCommand(t)
cmd.client = client
code := cmd.Run([]string{
"-version=" + version,
"-sha256=" + sha256Sum,
typ.String(),
pluginName,
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}
expected := "Success! Registered plugin: my-plugin"
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, expected) {
t.Errorf("expected %q to contain %q", combined, expected)
}
}
}
resp, err := client.Sys().ListPlugins(&api.ListPluginsInput{
Type: consts.PluginTypeUnknown,
})
if err != nil {
t.Fatal(err)
}
found := make(map[consts.PluginType]int)
versionsFound := make(map[consts.PluginType][]string)
for _, p := range resp.Details {
if p.Name == pluginName {
typ, err := consts.ParsePluginType(p.Type)
if err != nil {
t.Fatal(err)
}
found[typ]++
versionsFound[typ] = append(versionsFound[typ], p.Version)
}
}
for _, typ := range types {
if found[typ] != 2 {
t.Fatalf("expected %q to be found 2 times, but found it %d times for %s type in %#v", pluginName, found[typ], typ.String(), resp.Details)
}
sort.Strings(versions)
sort.Strings(versionsFound[typ])
if !reflect.DeepEqual(versions, versionsFound[typ]) {
t.Fatalf("expected %v versions but got %v", versions, versionsFound[typ])
}
}
})
t.Run("communication_failure", func(t *testing.T) {
t.Parallel()

View File

@ -32,6 +32,7 @@ type SecretsEnableCommand struct {
flagAllowedResponseHeaders []string
flagForceNoCache bool
flagPluginName string
flagPluginVersion string
flagOptions map[string]string
flagLocal bool
flagSealWrap bool
@ -173,6 +174,13 @@ func (c *SecretsEnableCommand) Flags() *FlagSets {
"exist in Vault's plugin catalog.",
})
f.StringVar(&StringVar{
Name: flagNamePluginVersion,
Target: &c.flagPluginVersion,
Default: "",
Usage: "Select the semantic version of the plugin to enable.",
})
f.StringMapVar(&StringMapVar{
Name: "options",
Target: &c.flagOptions,
@ -320,6 +328,10 @@ func (c *SecretsEnableCommand) Run(args []string) int {
if fl.Name == flagNameAllowedManagedKeys {
mountInput.Config.AllowedManagedKeys = c.flagAllowedManagedKeys
}
if fl.Name == flagNamePluginVersion {
mountInput.Config.PluginVersion = c.flagPluginVersion
}
})
if err := client.Sys().Mount(mountPath, mountInput); err != nil {

View File

@ -30,6 +30,7 @@ type SecretsTuneCommand struct {
flagAllowedResponseHeaders []string
flagOptions map[string]string
flagVersion int
flagPluginVersion string
flagAllowedManagedKeys []string
}
@ -146,6 +147,14 @@ func (c *SecretsTuneCommand) Flags() *FlagSets {
"each time with 1 key.",
})
f.StringVar(&StringVar{
Name: flagNamePluginVersion,
Target: &c.flagPluginVersion,
Default: "",
Usage: "Select the semantic version of the plugin to run. The new version must be registered in " +
"the plugin catalog, and will not start running until the plugin is reloaded.",
})
return set
}
@ -226,6 +235,10 @@ func (c *SecretsTuneCommand) Run(args []string) int {
if fl.Name == flagNameAllowedManagedKeys {
mountConfigInput.AllowedManagedKeys = c.flagAllowedManagedKeys
}
if fl.Name == flagNamePluginVersion {
mountConfigInput.PluginVersion = c.flagPluginVersion
}
})
if err := client.Sys().TuneMount(mountPath, mountConfigInput); err != nil {

View File

@ -6,6 +6,8 @@ import (
"github.com/go-test/deep"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
)
@ -148,7 +150,10 @@ func TestSecretsTuneCommand_Run(t *testing.T) {
t.Run("integration", func(t *testing.T) {
t.Run("flags_all", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
pluginDir, cleanup := vault.MakeTestPluginDir(t)
defer cleanup(t)
client, _, closer := testVaultServerPluginDir(t, pluginDir)
defer closer()
ui, cmd := testSecretsTuneCommand(t)
@ -161,6 +166,21 @@ func TestSecretsTuneCommand_Run(t *testing.T) {
t.Fatal(err)
}
mounts, err := client.Sys().ListMounts()
if err != nil {
t.Fatal(err)
}
mountInfo, ok := mounts["mount_tune_integration/"]
if !ok {
t.Fatalf("expected mount to exist")
}
if exp := ""; mountInfo.PluginVersion != exp {
t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp)
}
_, _, version := testPluginCreateAndRegisterVersioned(t, client, pluginDir, "pki", consts.PluginTypeSecrets)
code := cmd.Run([]string{
"-description", "new description",
"-default-lease-ttl", "30m",
@ -172,6 +192,7 @@ func TestSecretsTuneCommand_Run(t *testing.T) {
"-allowed-response-headers", "authorization,www-authentication",
"-allowed-managed-keys", "key1,key2",
"-listing-visibility", "unauth",
"-plugin-version", version,
"mount_tune_integration/",
})
if exp := 0; code != exp {
@ -184,12 +205,12 @@ func TestSecretsTuneCommand_Run(t *testing.T) {
t.Errorf("expected %q to contain %q", combined, expected)
}
mounts, err := client.Sys().ListMounts()
mounts, err = client.Sys().ListMounts()
if err != nil {
t.Fatal(err)
}
mountInfo, ok := mounts["mount_tune_integration/"]
mountInfo, ok = mounts["mount_tune_integration/"]
if !ok {
t.Fatalf("expected mount to exist")
}
@ -199,6 +220,9 @@ func TestSecretsTuneCommand_Run(t *testing.T) {
if exp := "pki"; mountInfo.Type != exp {
t.Errorf("expected %q to be %q", mountInfo.Type, exp)
}
if exp := version; mountInfo.PluginVersion != exp {
t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp)
}
if exp := 1800; mountInfo.Config.DefaultLeaseTTL != exp {
t.Errorf("expected %d to be %d", mountInfo.Config.DefaultLeaseTTL, exp)
}

View File

@ -318,8 +318,10 @@ func TestCore_EnableExternalCredentialPlugin_NoVersionOnRegister(t *testing.T) {
req := logical.TestRequest(t, logical.UpdateOperation, mountTable(tc.pluginType))
req.Data = map[string]interface{}{
"type": pluginName,
"plugin_version": "v1.0.0",
"type": pluginName,
"config": map[string]interface{}{
"plugin_version": "v1.0.0",
},
}
resp, _ := c.systemBackend.HandleRequest(namespace.RootContext(nil), req)
if resp == nil || !resp.IsError() || !strings.Contains(resp.Error().Error(), ErrPluginNotFound.Error()) {
@ -379,22 +381,7 @@ func TestExternalPlugin_getBackendTypeVersion(t *testing.T) {
} {
t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, tc.setRunningVersion)
d := &framework.FieldData{
Raw: map[string]interface{}{
"name": pluginName,
"sha256": pluginSHA256,
"version": tc.setRunningVersion,
"command": pluginName,
},
Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields,
}
resp, err := c.systemBackend.handlePluginCatalogUpdate(context.Background(), nil, d)
if err != nil {
t.Fatal(err)
}
if resp.Error() != nil {
t.Fatalf("%#v", resp)
}
registerPlugin(t, c.systemBackend, pluginName, tc.pluginType.String(), tc.setRunningVersion, pluginSHA256)
shaBytes, _ := hex.DecodeString(pluginSHA256)
commandFull := filepath.Join(c.pluginCatalog.directory, pluginName)
@ -407,6 +394,7 @@ func TestExternalPlugin_getBackendTypeVersion(t *testing.T) {
}
var version logical.PluginVersion
var err error
if tc.pluginType == consts.PluginTypeDatabase {
version, err = c.pluginCatalog.getDatabaseRunningVersion(context.Background(), entry)
} else {
@ -447,7 +435,9 @@ func mountPlugin(t *testing.T, sys *SystemBackend, pluginName string, pluginType
"type": pluginName,
}
if version != "" {
req.Data["plugin_version"] = version
req.Data["config"] = map[string]interface{}{
"plugin_version": version,
}
}
resp, err := sys.HandleRequest(namespace.RootContext(nil), req)
if err != nil {

View File

@ -1001,10 +1001,6 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d
sealWrap := data.Get("seal_wrap").(bool)
externalEntropyAccess := data.Get("external_entropy_access").(bool)
options := data.Get("options").(map[string]string)
var version string
if pluginVersionRaw, ok := data.GetOk("plugin_version"); ok {
version = pluginVersionRaw.(string)
}
var config MountConfig
var apiConfig APIMountConfig
@ -1110,6 +1106,7 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d
}
}
version := apiConfig.PluginVersion
switch version {
case "":
var err error
@ -2349,10 +2346,6 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque
sealWrap := data.Get("seal_wrap").(bool)
externalEntropyAccess := data.Get("external_entropy_access").(bool)
options := data.Get("options").(map[string]string)
var version string
if pluginVersionRaw, ok := data.GetOk("plugin_version"); ok {
version = pluginVersionRaw.(string)
}
var config MountConfig
var apiConfig APIMountConfig
@ -2446,6 +2439,7 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque
}
}
version := apiConfig.PluginVersion
switch version {
case "":
var err error

View File

@ -368,6 +368,7 @@ type APIMountConfig struct {
AllowedResponseHeaders []string `json:"allowed_response_headers,omitempty" structs:"allowed_response_headers" mapstructure:"allowed_response_headers"`
TokenType string `json:"token_type" structs:"token_type" mapstructure:"token_type"`
AllowedManagedKeys []string `json:"allowed_managed_keys,omitempty" mapstructure:"allowed_managed_keys"`
PluginVersion string `json:"plugin_version,omitempty" mapstructure:"plugin_version"`
// PluginName is the name of the plugin registered in the catalog.
//

View File

@ -922,6 +922,7 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug
// Users don't expect to see the plugin type, so we need to strip that here.
var normalizedName, version string
var semanticVersion *semver.Version
storedType := consts.PluginTypeUnknown
parts := strings.Split(plugin, "/")
switch len(parts) {
@ -933,7 +934,7 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug
return nil, err
}
case 2: // Unversioned
if isPluginType(parts[0]) {
if storedType, err = consts.ParsePluginType(parts[0]); err == nil {
normalizedName = parts[1]
// Use 0.0.0 to ensure unversioned is sorted as the oldest version.
semanticVersion, err = semver.NewVersion("0.0.0")
@ -941,13 +942,17 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug
return nil, err
}
} else {
return nil, fmt.Errorf("unknown plugin type in plugin catalog: %s", plugin)
return nil, fmt.Errorf("unknown plugin type in plugin catalog: %s: %w", plugin, err)
}
case 3: // Versioned, with type
if !includeVersioned {
continue
}
storedType, err = consts.ParsePluginType(parts[0])
if err != nil {
return nil, fmt.Errorf("unexpected error parsing plugin type from plugin catalog entry %q: %w", plugin, err)
}
normalizedName, version = parts[1], parts[2]
semanticVersion, err = semver.NewVersion(version)
if err != nil {
@ -958,18 +963,24 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug
}
// Only list user-added plugins if they're of the given type.
if entry, err := c.get(ctx, normalizedName, pluginType, version); err == nil && entry != nil {
result = append(result, pluginutil.VersionedPlugin{
Name: normalizedName,
Type: pluginType.String(),
Version: version,
SHA256: hex.EncodeToString(entry.Sha256),
SemanticVersion: semanticVersion,
})
if storedType != consts.PluginTypeUnknown && storedType != pluginType {
continue
}
entry, err := c.get(ctx, normalizedName, pluginType, version)
if err != nil || entry == nil {
continue
}
if version == "" {
unversionedPlugins[normalizedName] = struct{}{}
}
result = append(result, pluginutil.VersionedPlugin{
Name: normalizedName,
Type: pluginType.String(),
Version: version,
SHA256: hex.EncodeToString(entry.Sha256),
SemanticVersion: semanticVersion,
})
if version == "" {
unversionedPlugins[normalizedName] = struct{}{}
}
}
@ -999,8 +1010,3 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug
return result, nil
}
func isPluginType(s string) bool {
_, err := consts.ParsePluginType(s)
return err == nil
}