diff --git a/internal/client.go b/internal/client.go index 08de8ff..2fa4b21 100644 --- a/internal/client.go +++ b/internal/client.go @@ -21,6 +21,8 @@ type Client struct { Id string `yaml:"id"` Secret string `yaml:"secret"` RedirectUris []string `yaml:"redirect-uris"` + FlowId string + CsrfToken string } func NewClientWithConfig(config *Config) *Client { @@ -33,6 +35,10 @@ func NewClientWithConfig(config *Config) *Client { } } +func (client *Client) IsFlowInitiated() bool { + return client.FlowId != "" +} + func (client *Client) BuildAuthorizationUrl(authEndpoint string, state string, responseType string, scope []string) string { return authEndpoint + "?" + "client_id=" + client.Id + "&redirect_uri=" + util.URLEscape(strings.Join(client.RedirectUris, ",")) + @@ -41,6 +47,80 @@ func (client *Client) BuildAuthorizationUrl(authEndpoint string, state string, r "&scope=" + strings.Join(scope, "+") } +func (client *Client) InitiateLoginFlow(loginUrl string) error { + // kratos: GET /self-service/login/api + req, err := http.NewRequest("GET", loginUrl, bytes.NewBuffer([]byte{})) + if err != nil { + return fmt.Errorf("failed to make request: %v", err) + } + res, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to do request: %v", err) + } + defer res.Body.Close() + + // get the flow ID from response + body, err := io.ReadAll(res.Body) + + var flowData map[string]any + err = json.Unmarshal(body, &flowData) + if err != nil { + return fmt.Errorf("failed to unmarshal flow data: %v\n%v", err, string(body)) + } else { + client.FlowId = flowData["id"].(string) + } + return nil +} + +func (client *Client) FetchFlowData(flowUrl string) (JsonObject, error) { + //kratos: GET /self-service/login/flows?id={flowId} + + // replace {id} in string with actual value + flowUrl = strings.ReplaceAll(flowUrl, "{id}", client.FlowId) + req, err := http.NewRequest("GET", flowUrl, nil) + if err != nil { + return nil, fmt.Errorf("failed to make request: %v", err) + } + res, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to do request: %v", err) + } + defer res.Body.Close() + + // get the flow data from response + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %v", err) + } + + var flowData JsonObject + err = json.Unmarshal(body, &flowData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal flow data: %v", err) + } + return flowData, nil +} + +func (client *Client) FetchCSRFToken(flowUrl string) error { + data, err := client.FetchFlowData(flowUrl) + if err != nil { + return fmt.Errorf("failed to fetch flow data: %v", err) + } + + // iterate through nodes and extract the CSRF token attribute from the flow data + ui := data["ui"].(map[string]any) + nodes := ui["nodes"].([]any) + for _, node := range nodes { + attrs := node.(map[string]any)["attributes"].(map[string]any) + name := attrs["name"].(string) + if name == "csrf_token" { + client.CsrfToken = attrs["value"].(string) + return nil + } + } + return fmt.Errorf("failed to extract CSRF token: not found") +} + func (client *Client) FetchTokenFromAuthenticationServer(code string, remoteUrl string, state string) ([]byte, error) { data := url.Values{ "grant_type": {"authorization_code"}, @@ -56,14 +136,19 @@ func (client *Client) FetchTokenFromAuthenticationServer(code string, remoteUrl } defer res.Body.Close() + domain, _ := url.Parse("http://127.0.0.1") + client.Jar.SetCookies(domain, res.Cookies()) + return io.ReadAll(res.Body) } func (client *Client) FetchTokenFromAuthorizationServer(remoteUrl string, jwt string, scope []string) ([]byte, error) { // hydra endpoint: /oauth/token data := "grant_type=" + util.URLEscape("urn:ietf:params:oauth:grant-type:jwt-bearer") + - "&assertion=" + jwt + - "&scope=" + strings.Join(scope, "+") + "&client_id=" + client.Id + + "&client_secret=" + client.Secret + + "&scope=" + strings.Join(scope, "+") + + "&assertion=" + jwt fmt.Printf("encoded params: %v\n\n", data) req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer([]byte(data))) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") @@ -76,11 +161,14 @@ func (client *Client) FetchTokenFromAuthorizationServer(remoteUrl string, jwt st } defer res.Body.Close() + // set flow ID back to empty string to indicate a completed flow + client.FlowId = "" + return io.ReadAll(res.Body) } func (client *Client) AddTrustedIssuer(remoteUrl string, idp *oidc.IdentityProvider, subject string, duration time.Duration, scope []string) ([]byte, error) { - // hydra endpoint: /admin/trust/grants/jwt-bearer/issuers + // hydra endpoint: POST /admin/trust/grants/jwt-bearer/issuers if idp == nil { return nil, fmt.Errorf("identity provided is nil") } @@ -88,14 +176,20 @@ func (client *Client) AddTrustedIssuer(remoteUrl string, idp *oidc.IdentityProvi if err != nil { return nil, fmt.Errorf("failed to marshal JWK: %v", err) } + quotedScopes := make([]string, len(scope)) + for i, s := range scope { + quotedScopes[i] = fmt.Sprintf("\"%s\"", s) + } + // NOTE: Can also include "jwks_uri" instead data := []byte(fmt.Sprintf(`{ - "allow_any_subject": true, + "allow_any_subject": false, "issuer": "%s", - "subject": "%s" - "expires_at": "%v" + "subject": "%s", + "expires_at": "%v", "jwk": %v, - "scope": [ %s ], - }`, idp.Issuer, subject, time.Now().Add(duration), string(jwkstr), strings.Join(scope, ","))) + "scope": [ %s ] + }`, idp.Issuer, subject, time.Now().Add(duration).Format(time.RFC3339), string(jwkstr), strings.Join(quotedScopes, ","))) + fmt.Printf("%v\n", string(data)) req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer(data)) // req.Header.Add("X-CSRF-Token", client.CsrfToken.Value) @@ -113,6 +207,32 @@ func (client *Client) AddTrustedIssuer(remoteUrl string, idp *oidc.IdentityProvi return io.ReadAll(res.Body) } +func (client *Client) RegisterOAuthClient(registerUrl string) ([]byte, error) { + // hydra endpoint: POST /clients + data := []byte(fmt.Sprintf(`{ + "client_name": "%s", + "client_secret": "%s", + "token_endpoint_auth_method": "client_secret_post", + "scope": "openid email profile", + "grant_types": ["client_credentials", "urn:ietf:params:oauth:grant-type:jwt-bearer"], + "response_types": ["token"] + }`, client.Id, client.Secret)) + + req, err := http.NewRequest("POST", registerUrl, bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("failed to make request: %v", err) + } + req.Header.Add("Content-Type", "application/json") + // req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", idToken)) + res, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to do request: %v", err) + } + defer res.Body.Close() + + return io.ReadAll(res.Body) +} + func (client *Client) CreateIdentity(remoteUrl string, idToken string) ([]byte, error) { // kratos endpoint: /admin/identities data := []byte(`{ @@ -150,3 +270,8 @@ func (client *Client) FetchIdentities(remoteUrl string) ([]byte, error) { return io.ReadAll(res.Body) } + +func (client *Client) ClearCookies() { + jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + client.Jar = jar +} diff --git a/internal/config.go b/internal/config.go index e507c2f..b9c3946 100644 --- a/internal/config.go +++ b/internal/config.go @@ -22,6 +22,7 @@ type Config struct { OpenBrowser bool `yaml:"open-browser"` DecodeIdToken bool `yaml:"decode-id-token"` DecodeAccessToken bool `yaml:"decode-access-token"` + RunOnce bool `yaml:"run-once"` } func NewConfig() Config { @@ -49,6 +50,7 @@ func NewConfig() Config { OpenBrowser: false, DecodeIdToken: false, DecodeAccessToken: false, + RunOnce: true, } } diff --git a/internal/login.go b/internal/login.go index 28a8a2b..d4e3f12 100644 --- a/internal/login.go +++ b/internal/login.go @@ -19,6 +19,18 @@ func Login(config *Config) error { server := NewServerWithConfig(config) client := NewClientWithConfig(config) + // initiate the login flow and get a flow ID and CSRF token + { + err := client.InitiateLoginFlow(config.ActionUrls.Login) + if err != nil { + return fmt.Errorf("failed to initiate login flow: %v", err) + } + err = client.FetchCSRFToken(config.ActionUrls.LoginFlowId) + if err != nil { + return fmt.Errorf("failed to fetch CSRF token: %v", err) + } + } + // try and fetch server configuration if provided URL idp := oidc.NewIdentityProvider() if config.ActionUrls.ServerConfig != "" { @@ -35,7 +47,7 @@ func Login(config *Config) error { } // check if all appropriate parameters are set in config - if !hasRequiredParams(config) { + if !HasRequiredParams(config) { return fmt.Errorf("client ID must be set") } @@ -70,6 +82,24 @@ func Login(config *Config) error { fmt.Printf("client did not initialize\n") } + // start up another serve in background to listen for success or failures + d := make(chan []byte) + quit := make(chan bool) + var access_token []byte + go server.Serve(d) + go func() { + select { + case <-d: + fmt.Printf("got access token") + quit <- true + case <-quit: + close(d) + close(quit) + return + default: + } + }() + // use code from response and exchange for bearer token (with ID token) tokenString, err := client.FetchTokenFromAuthenticationServer( code, @@ -93,6 +123,7 @@ func Login(config *Config) error { if err != nil { fmt.Printf("failed to parse ID token: %v\n", err) } else { + fmt.Printf("token: %v\n", idToken) if config.DecodeIdToken { if err != nil { fmt.Printf("failed to decode JWT: %v\n", err) @@ -117,6 +148,18 @@ func Login(config *Config) error { // } // } + // extract the scope from access token claims + // var scope []string + // var accessJsonPayload map[string]any + // var accessJwtPayload []byte = accessJwtSegments[1] + // if accessJsonPayload != nil { + // err := json.Unmarshal(accessJwtPayload, &accessJsonPayload) + // if err != nil { + // return fmt.Errorf("failed to unmarshal JWT: %v", err) + // } + // scope = idJsonPayload["scope"].([]string) + // } + // create a new identity with identity and session manager if url is provided if config.ActionUrls.Identities != "" { fmt.Printf("Attempting to create a new identity...\n") @@ -145,44 +188,47 @@ func Login(config *Config) error { return fmt.Errorf("failed to extract subject from ID token claims") } - // extract the scope from access token claims - // var scope []string - // var accessJsonPayload map[string]any - // var accessJwtPayload []byte = accessJwtSegments[1] - // if accessJsonPayload != nil { - // err := json.Unmarshal(accessJwtPayload, &accessJsonPayload) - // if err != nil { - // return fmt.Errorf("failed to unmarshal JWT: %v", err) - // } - // scope = idJsonPayload["scope"].([]string) - // } - // fetch JWKS and add issuer to authentication server to submit ID token fmt.Printf("Fetching JWKS from authentication server for verification...\n") err = idp.FetchJwk(config.ActionUrls.JwksUri) if err != nil { - return fmt.Errorf("failed to fetch JWK: %v\n", err) + return fmt.Errorf("failed to fetch JWK: %v", err) } else { fmt.Printf("Attempting to add issuer to authorization server...\n") res, err := client.AddTrustedIssuer(config.ActionUrls.TrustedIssuers, idp, subject, time.Duration(1000), config.Scope) if err != nil { return fmt.Errorf("failed to add trusted issuer: %v", err) } - if string(res) == "" { - fmt.Printf("Added issuer to authorization server successfully.\n") - } + fmt.Printf("%v\n", string(res)) + } + + // try and register a new client with authorization server + res, err := client.RegisterOAuthClient("http://127.0.0.1:4445/clients") + if err != nil { + return fmt.Errorf("failed to register client: %v", err) + } + fmt.Printf("%v\n", string(res)) + + // extract the client info from response + var clientData map[string]any + err = json.Unmarshal(res, &clientData) + if err != nil { + return fmt.Errorf("failed to unmarshal client data: %v", err) + } else { + client.Id = clientData["client_id"].(string) + client.Secret = clientData["client_secret"].(string) } // use ID token/user info to fetch access token from authentication server if config.ActionUrls.AccessToken != "" { fmt.Printf("Fetching access token from authorization server...\n") - accessToken, err := client.FetchTokenFromAuthorizationServer(config.ActionUrls.AccessToken, idToken, config.Scope) + res, err := client.FetchTokenFromAuthorizationServer(config.ActionUrls.AccessToken, idToken, config.Scope) if err != nil { return fmt.Errorf("failed to fetch access token: %v", err) } - fmt.Printf("%s\n", accessToken) + fmt.Printf("%s\n", res) } - fmt.Printf("Success!") + d <- access_token return nil } diff --git a/internal/opaal.go b/internal/opaal.go index b82cbea..67dcf81 100644 --- a/internal/opaal.go +++ b/internal/opaal.go @@ -6,8 +6,10 @@ type ActionUrls struct { AccessToken string `yaml:"access-token"` ServerConfig string `yaml:"server-config"` JwksUri string `yaml:"jwks_uri"` + Login string `yaml:"login"` + LoginFlowId string `yaml:"login-flow-id"` } -func hasRequiredParams(config *Config) bool { +func HasRequiredParams(config *Config) bool { return config.Client.Id != "" && config.Client.Secret != "" } diff --git a/internal/server.go b/internal/server.go index 3936b00..c048592 100644 --- a/internal/server.go +++ b/internal/server.go @@ -5,10 +5,13 @@ import ( "net/http" "os" "strings" + + "github.com/go-chi/chi/middleware" + "github.com/go-chi/chi/v5" ) type Server struct { - http.Server + *http.Server Host string `yaml:"host"` Port int `yaml:"port"` } @@ -17,16 +20,16 @@ func NewServerWithConfig(config *Config) *Server { host := config.Server.Host port := config.Server.Port server := &Server{ + Server: &http.Server{ + Addr: fmt.Sprintf("%s:%d", host, port), + }, Host: host, Port: port, } - server.Addr = fmt.Sprintf("%s:%d", host, port) return server } func (s *Server) SetListenAddr(host string, port int) { - s.Host = host - s.Port = port s.Addr = s.GetListenAddr() } @@ -36,34 +39,69 @@ func (s *Server) GetListenAddr() string { func (s *Server) WaitForAuthorizationCode(loginUrl string) (string, error) { var code string - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + r := chi.NewRouter() + r.Use(middleware.RedirectSlashes) + r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/login", http.StatusSeeOther) }) - http.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) { + r.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) { // show login page with notice to redirect loginPage, err := os.ReadFile("pages/index.html") if err != nil { fmt.Printf("failed to load login page: %v\n", err) } loginPage = []byte(strings.ReplaceAll(string(loginPage), "{{loginUrl}}", loginUrl)) - w.WriteHeader(http.StatusSeeOther) w.Write(loginPage) }) - http.HandleFunc("/oidc/callback", func(w http.ResponseWriter, r *http.Request) { + r.HandleFunc("/oidc/callback", func(w http.ResponseWriter, r *http.Request) { // get the code from the OIDC provider if r != nil { code = r.URL.Query().Get("code") fmt.Printf("Authorization code: %v\n", code) } - http.Redirect(w, r, s.Addr+"/success", http.StatusSeeOther) - s.Close() + http.Redirect(w, r, "/redirect", http.StatusSeeOther) }) + r.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + err := s.Close() + if err != nil { + fmt.Printf("failed to close server: %v\n", err) + } + }) + s.Handler = r + return code, s.ListenAndServe() } -func (s *Server) ShowSuccessPage() error { - http.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) { +func (s *Server) Serve(data chan []byte) error { + output, ok := <-data + if !ok { + return fmt.Errorf("failed to receive data") + } + fmt.Printf("Received data: %v\n", string(output)) + // http.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + + // }) + r := chi.NewRouter() + r.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) { + fmt.Printf("Serving success page.") + successPage, err := os.ReadFile("pages/success.html") + if err != nil { + fmt.Printf("failed to load success page: %v\n", err) + } + successPage = []byte(strings.ReplaceAll(string(successPage), "{{access_token}}", string(output))) + w.Write(successPage) }) + r.HandleFunc("/error", func(w http.ResponseWriter, r *http.Request) { + fmt.Printf("Serving error page.") + errorPage, err := os.ReadFile("pages/success.html") + if err != nil { + fmt.Printf("failed to load success page: %v\n", err) + } + // errorPage = []byte(strings.ReplaceAll(string(errorPage), "{{access_token}}", output)) + w.Write(errorPage) + }) + + s.Handler = r return s.ListenAndServe() } diff --git a/internal/util/util.go b/internal/util/util.go index 3a28bc5..ff08438 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -92,3 +92,11 @@ func GetCommit() string { } return string(bytes) } + +func Tokenize(s string) map[string]any { + tokens := make(map[string]any) + + // find token enclosed in curly brackets + + return tokens +}