diff --git a/cmd/config.go b/cmd/config.go index e1d8426..a37fec0 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -1,6 +1,75 @@ package cmd -import "github.com/spf13/cobra" +import ( + "davidallendj/oidc-auth/internal/util" + "log" + "os" + "path/filepath" + + "github.com/spf13/cobra" + yaml "gopkg.in/yaml.v2" +) + +type Config struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + RedirectUri []string `yaml:"redirect-uri"` + State string `yaml:"state"` + ResponseType string `yaml:"response-type"` + Scope []string `yaml:"scope"` + ClientId string `yaml:"client.id"` + ClientSecret string `yaml:"client.secret"` + OIDCHost string `yaml:"oidc.host"` + OIDCPort int `yaml:"oidc.port"` +} + +func NewConfig() Config { + return Config{ + Host: "127.0.0.1", + Port: 3333, + RedirectUri: []string{""}, + State: util.RandomString(20), + ResponseType: "code", + Scope: []string{"openid", "profile", "email"}, + ClientId: "", + ClientSecret: "", + OIDCHost: "127.0.0.1", + OIDCPort: 80, + } +} + +func LoadConfig(path string) Config { + var c Config = NewConfig() + file, err := os.ReadFile(path) + if err != nil { + log.Printf("failed to read config file: %v\n", err) + return c + } + err = yaml.Unmarshal(file, &c) + if err != nil { + log.Fatalf("failed to unmarshal config: %v\n", err) + return c + } + return c +} + +func SaveDefaultConfig(path string) { + path = filepath.Clean(path) + if path == "" || path == "." { + path = "config.yaml" + } + var c = NewConfig() + data, err := yaml.Marshal(c) + if err != nil { + log.Printf("failed to marshal config: %v\n", err) + return + } + err = os.WriteFile(path, data, os.ModePerm) + if err != nil { + log.Printf("failed to write default config file: %v\n", err) + return + } +} var configCmd = &cobra.Command{ Use: "config", @@ -8,7 +77,7 @@ var configCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { // create a new config at all args (paths) for _, path := range args { - _ = path + SaveDefaultConfig(path) } }, } diff --git a/cmd/login.go b/cmd/login.go index 9232006..66d3d09 100644 --- a/cmd/login.go +++ b/cmd/login.go @@ -1,7 +1,6 @@ package cmd import ( - "davidallendj/oidc-auth/internal/oauth" "davidallendj/oidc-auth/internal/oidc" "davidallendj/oidc-auth/internal/server" "davidallendj/oidc-auth/internal/util" @@ -13,36 +12,33 @@ import ( "github.com/spf13/cobra" ) -var ( - host string - port int - redirectUri = []string{""} - state = "" - responseType = "code" - scope = []string{"email", "profile", "openid"} - client oauth.Client -) - var loginCmd = &cobra.Command{ Use: "login", Short: "Start the login flow", Run: func(cmd *cobra.Command, args []string) { + if configPath != "" { + config = LoadConfig(configPath) + } else { + config = NewConfig() + } oidcProvider := oidc.NewOIDCProvider() + oidcProvider.Host = config.OIDCHost + oidcProvider.Port = config.OIDCPort var authorizationUrl = util.BuildAuthorizationUrl( oidcProvider.GetAuthorizeUrl(), - client.Id, - redirectUri, - util.RandomString(20), - responseType, - []string{"email", "profile", "openid"}, + config.ClientId, + config.RedirectUri, + config.State, + config.ResponseType, + config.Scope, ) // print the authorization URL for the user to log in fmt.Printf("Login with identity provider: %s\n", authorizationUrl) - // start a HTTP server to listen for callback responses + // authorize oauth client and listen for callback from provider fmt.Printf("Waiting for response from OIDC provider...\n") - err := server.Start(host, port) + code, err := server.WaitForAuthorizationCode(config.Host, config.Port) if errors.Is(err, http.ErrServerClosed) { fmt.Printf("server closed\n") } else if err != nil { @@ -50,7 +46,8 @@ var loginCmd = &cobra.Command{ os.Exit(1) } - // extract code from response and exchange for bearer token + // use code from response and exchange for bearer token + server.FetchToken(code, oidcProvider.GetTokenUrl(), config.ClientId, config.ClientSecret, config.State, config.RedirectUri) // extract ID token and save user info @@ -61,12 +58,12 @@ var loginCmd = &cobra.Command{ } func init() { - loginCmd.Flags().StringVar(&client.Id, "client.id", "", "set the client ID") - loginCmd.Flags().StringSliceVar(&redirectUri, "redirect-uri", []string{""}, "set the redirect URI") - loginCmd.Flags().StringVar(&responseType, "response-type", "code", "set the response-type") - loginCmd.Flags().StringSliceVar(&scope, "scope", []string{"openid", "email"}, "set the scopes") - loginCmd.Flags().StringVar(&state, "state", util.RandomString(20), "set the state") - loginCmd.Flags().StringVar(&host, "host", "127.0.0.1", "set the listening host") - loginCmd.Flags().IntVar(&port, "port", 3333, "set the listening port") + loginCmd.Flags().StringVar(&config.ClientId, "client.id", config.ClientId, "set the client ID") + loginCmd.Flags().StringSliceVar(&config.RedirectUri, "redirect-uri", config.RedirectUri, "set the redirect URI") + loginCmd.Flags().StringVar(&config.ResponseType, "response-type", config.ResponseType, "set the response-type") + loginCmd.Flags().StringSliceVar(&config.Scope, "scope", config.Scope, "set the scopes") + loginCmd.Flags().StringVar(&config.State, "state", config.State, "set the state") + loginCmd.Flags().StringVar(&config.Host, "host", config.Host, "set the listening host") + loginCmd.Flags().IntVar(&config.Port, "port", config.Port, "set the listening port") rootCmd.AddCommand(loginCmd) } diff --git a/cmd/root.go b/cmd/root.go index 37c85fc..9a28af6 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -7,7 +7,10 @@ import ( "github.com/spf13/cobra" ) -var configPath = "" +var ( + configPath = "" + config Config +) var rootCmd = &cobra.Command{ Use: "oidc", Short: "An experimental OIDC helper tool for handling logins", @@ -18,11 +21,11 @@ var rootCmd = &cobra.Command{ func Execute() { if err := rootCmd.Execute(); err != nil { - fmt.Fprintf(os.Stderr, "Whoops. There was an error while executing your CLI '%s'", err) + fmt.Fprintf(os.Stderr, "failed to start CLI: %s", err) os.Exit(1) } } func init() { - rootCmd.Flags().StringVar(&configPath, "config", "", "set the config path") + rootCmd.PersistentFlags().StringVar(&configPath, "config", "", "set the config path") } diff --git a/go.mod b/go.mod index f823b7a..cb17d59 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,10 @@ module davidallendj/oidc-auth go 1.22.0 -require github.com/spf13/cobra v1.8.0 +require ( + github.com/spf13/cobra v1.8.0 + gopkg.in/yaml.v2 v2.4.0 +) require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index d0e8c2c..d90dbd8 100644 --- a/go.sum +++ b/go.sum @@ -6,5 +6,8 @@ github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/oidc/oidc-auth b/internal/oidc/oidc-auth deleted file mode 100755 index f26efe6..0000000 Binary files a/internal/oidc/oidc-auth and /dev/null differ diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 6e15761..4f1821e 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -1,7 +1,10 @@ package oidc +import "fmt" + type OpenIDConnectProvider struct { Host string + Port int AuthorizeEndpoint string TokenEndpoint string ConfigEndpoint string @@ -9,17 +12,24 @@ type OpenIDConnectProvider struct { func NewOIDCProvider() *OpenIDConnectProvider { return &OpenIDConnectProvider{ - Host: "https://gitlab.newmexicoconsortium.org", + Host: "127.0.0.1", + Port: 80, AuthorizeEndpoint: "/oauth/authorize", TokenEndpoint: "/oauth/token", } } func (oidc *OpenIDConnectProvider) GetAuthorizeUrl() string { + if oidc.Port != 80 { + return fmt.Sprintf("%s:%d", oidc.Host, oidc.Port) + oidc.AuthorizeEndpoint + } return oidc.Host + oidc.AuthorizeEndpoint } func (oidc *OpenIDConnectProvider) GetTokenUrl() string { + if oidc.Port != 80 { + return fmt.Sprintf("%s:%d", oidc.Host, oidc.Port) + oidc.TokenEndpoint + } return oidc.Host + oidc.TokenEndpoint } diff --git a/internal/server/server.go b/internal/server/server.go index 62a1a8d..663b48a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,16 +1,48 @@ package server import ( + "davidallendj/oidc-auth/internal/util" "fmt" "net/http" + "net/url" + "os" + "strings" ) -func Start(host string, port int) error { - http.HandleFunc("/oauth/callback", getAuthorizationCode) - err := http.ListenAndServe(host+":"+fmt.Sprintf("%d", port), nil) - return err +func WaitForAuthorizationCode(host string, port int) (string, error) { + var code string + s := &http.Server{ + Addr: fmt.Sprintf("%s:%d", host, port), + } + http.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) { + // get the code from the OIDC provider + if r != nil { + code = r.URL.Query().Get("code") + fmt.Printf("Authorization Code: %v\n", code) + } + s.Close() + + }) + return code, s.ListenAndServe() } -func getAuthorizationCode(w http.ResponseWriter, r *http.Request) { - fmt.Printf("response from OIDC provider: %v\n", r) +func FetchToken(code string, remoteUrl string, clientId string, clientSecret string, state string, redirectUri []string) (string, error) { + var token string + data := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "client_id": {clientId}, + "client_secret": {clientSecret}, + "state": {state}, + "redirect_uri": {util.EncodeURL(strings.Join(redirectUri, ","))}, + } + res, err := http.PostForm(remoteUrl, data) + if err != nil { + fmt.Printf("failed to get token: %s\n", err) + os.Exit(1) + } + + fmt.Printf("request URL: %s\n", remoteUrl) + fmt.Printf("token response: %v\n", res) + return token, nil } diff --git a/internal/util/util.go b/internal/util/util.go index 56a74e3..92b6eb7 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -1,18 +1,19 @@ package util import ( + "encoding/base64" "math/rand" + "net/url" "strings" ) -const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -const ( - letterIdxBits = 6 // 6 bits to represent a letter index - letterIdxMask = 1<= 0; { @@ -31,9 +32,17 @@ func RandomString(n int) string { } func BuildAuthorizationUrl(authEndpoint string, clientId string, redirectUri []string, state string, responseType string, scope []string) string { - return authEndpoint + "?" + "cilent_id=" + clientId + - "&redirect_url=" + strings.Join(redirectUri, ",") + + return authEndpoint + "?" + "client_id=" + clientId + + "&redirect_uri=" + EncodeURL(strings.Join(redirectUri, ",")) + "&response_type=" + responseType + "&state=" + state + "&scope=" + strings.Join(scope, "+") } + +func EncodeURL(s string) string { + return url.QueryEscape(s) +} + +func EncodeBase64(s string) string { + return base64.StdEncoding.EncodeToString([]byte(s)) +}