diff --git a/cmd/login.go b/cmd/login.go index 187f266..0890ae9 100644 --- a/cmd/login.go +++ b/cmd/login.go @@ -2,16 +2,20 @@ package cmd import ( opaal "davidallendj/opaal/internal" - "davidallendj/opaal/internal/db" + cache "davidallendj/opaal/internal/cache/sqlite" + "davidallendj/opaal/internal/oauth" "davidallendj/opaal/internal/oidc" "fmt" "os" + "slices" "github.com/spf13/cobra" ) var ( - client opaal.Client + client oauth.Client + target string = "" + targetIndex int = -1 ) var loginCmd = &cobra.Command{ @@ -21,24 +25,57 @@ var loginCmd = &cobra.Command{ for { // try and find client with valid identity provider config var provider *oidc.IdentityProvider - for _, c := range config.Authentication.Clients { - // try to get identity provider info locally first - _, err := db.GetIdentityProvider(config.Options.CachePath, c.Issuer) - if err != nil && !config.Options.LocalOnly { - fmt.Printf("fetching config from issuer: %v\n", c.Issuer) - // try to get info remotely by fetching - provider, err = oidc.FetchServerConfig(c.Issuer) - if err != nil { - fmt.Printf("failed to fetch server config: %v\n", err) - continue + if target != "" { + // only try to use client with name give + index := slices.IndexFunc(config.Authentication.Clients, func(c oauth.Client) bool { + return target == c.Name + }) + if index < 0 { + fmt.Printf("could not find the target client listed by name") + os.Exit(1) + } + client := config.Authentication.Clients[index] + _, err := cache.GetIdentityProvider(config.Options.CachePath, client.Issuer) + if err != nil { + + } + + } else if targetIndex >= 0 { + // only try to use client by index + targetCount := len(config.Authentication.Clients) - 1 + if targetIndex > targetCount { + fmt.Printf("target index out of range (found %d)", targetCount) + } + client := config.Authentication.Clients[targetIndex] + _, err := cache.GetIdentityProvider(config.Options.CachePath, client.Issuer) + if err != nil { + + } + } else { + for _, c := range config.Authentication.Clients { + // try to get identity provider info locally first + _, err := cache.GetIdentityProvider(config.Options.CachePath, c.Issuer) + if err != nil && !config.Options.CacheOnly { + fmt.Printf("fetching config from issuer: %v\n", c.Issuer) + // try to get info remotely by fetching + provider, err = oidc.FetchServerConfig(c.Issuer) + if err != nil { + fmt.Printf("failed to fetch server config: %v\n", err) + continue + } + client = c + // fetch the provider's JWKS + err := provider.FetchJwks() + if err != nil { + fmt.Printf("failed to fetch JWKS: %v\n", err) + } + break } - client = c - // fetch the provider's JWKS - err := provider.FetchJwks() - if err != nil { - fmt.Printf("failed to fetch JWKS: %v\n", err) + // only test the first if --run-all flag is not set + if !config.Authentication.TestAllClients { + fmt.Printf("stopping after first test...\n\n\n") + break } - break } } @@ -66,10 +103,13 @@ func init() { loginCmd.Flags().StringVar(&config.Server.Host, "server.host", config.Server.Host, "set the listening host") loginCmd.Flags().IntVar(&config.Server.Port, "server.port", config.Server.Port, "set the listening port") loginCmd.Flags().BoolVar(&config.Options.OpenBrowser, "open-browser", config.Options.OpenBrowser, "automatically open link in browser") - loginCmd.Flags().BoolVar(&config.Options.DecodeIdToken, "decode-id-token", config.Options.DecodeIdToken, "decode and print ID token from identity provider") - loginCmd.Flags().BoolVar(&config.Options.DecodeAccessToken, "decore-access-token", config.Options.DecodeAccessToken, "decode and print access token from authorization server") loginCmd.Flags().BoolVar(&config.Options.RunOnce, "once", config.Options.RunOnce, "set whether to run login once and exit") loginCmd.Flags().StringVar(&config.Options.FlowType, "flow", config.Options.FlowType, "set the grant-type/authorization flow") - loginCmd.Flags().BoolVar(&config.Options.LocalOnly, "local", config.Options.LocalOnly, "only fetch identity provider configs stored locally") + loginCmd.Flags().BoolVar(&config.Options.CacheOnly, "local", config.Options.CacheOnly, "only fetch identity provider configs stored locally") + loginCmd.Flags().BoolVar(&config.Authentication.TestAllClients, "test-all", config.Authentication.TestAllClients, "test all clients in config for a valid provider") + loginCmd.Flags().StringVar(&target, "target", "", "set target client to use from config by name") + loginCmd.Flags().IntVar(&targetIndex, "index", -1, "set target client to use from config by index") + loginCmd.MarkFlagsMutuallyExclusive("target", "index") + rootCmd.AddCommand(loginCmd) } diff --git a/cmd/root.go b/cmd/root.go index 4ec5905..2d3e916 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -30,7 +30,7 @@ func Execute() { func init() { cobra.OnInitialize(initConfig) - rootCmd.PersistentFlags().StringVar(&confPath, "config", "", "set the config path") + rootCmd.PersistentFlags().StringVarP(&confPath, "config", "c", "", "set the config path") rootCmd.PersistentFlags().StringVar(&config.Options.CachePath, "cache", "", "set the cache path") }