diff --git a/cmd/serve.go b/cmd/serve.go index e150b34..7eeb459 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -51,6 +51,7 @@ var serveCmd = &cobra.Command{ // set up the routes and start the server server := server.Server{ + Config: &config, Server: &http.Server{ Addr: fmt.Sprintf("%s:%d", config.Server.Host, config.Server.Port), }, @@ -65,7 +66,7 @@ var serveCmd = &cobra.Command{ Verbose: verbose, }, } - err := server.Serve(&config) + err := server.Serve() if errors.Is(err, http.ErrServerClosed) { fmt.Printf("Server closed.") } else if err != nil { diff --git a/internal/server/server.go b/internal/server/server.go index bf7534d..5c16ca6 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -27,6 +27,7 @@ type Jwks struct { } type Server struct { *http.Server + Config *configurator.Config Jwks Jwks `yaml:"jwks"` GeneratorParams generator.Params TokenAuth *jwtauth.JWTAuth @@ -44,21 +45,21 @@ func New() *Server { } } -func (s *Server) Serve(config *configurator.Config) error { +func (s *Server) Serve() error { // create client just for the server to use to fetch data from SMD _ = &configurator.SmdClient{ - Host: config.SmdClient.Host, - Port: config.SmdClient.Port, + Host: s.Config.SmdClient.Host, + Port: s.Config.SmdClient.Port, } // set the server address with config values - s.Server.Addr = fmt.Sprintf("%s:%d", config.Server.Host, config.Server.Port) + s.Server.Addr = fmt.Sprintf("%s:%d", s.Config.Server.Host, s.Config.Server.Port) // fetch JWKS public key from authorization server - if config.Server.Jwks.Uri != "" && tokenAuth == nil { - for i := 0; i < config.Server.Jwks.Retries; i++ { + if s.Config.Server.Jwks.Uri != "" && tokenAuth == nil { + for i := 0; i < s.Config.Server.Jwks.Retries; i++ { var err error - tokenAuth, err = configurator.FetchPublicKeyFromURL(config.Server.Jwks.Uri) + tokenAuth, err = configurator.FetchPublicKeyFromURL(s.Config.Server.Jwks.Uri) if err != nil { logrus.Errorf("failed to fetch JWKS: %w", err) continue @@ -67,55 +68,75 @@ func (s *Server) Serve(config *configurator.Config) error { } } - var WriteError = func(w http.ResponseWriter, format string, a ...any) { - errmsg := fmt.Sprintf(format, a...) - fmt.Printf(errmsg) - w.Write([]byte(errmsg)) - } - // create new go-chi router with its routes router := chi.NewRouter() - router.Use(middleware.RedirectSlashes) + router.Use(middleware.RequestID) + router.Use(middleware.RealIP) + router.Use(middleware.Logger) + router.Use(middleware.Recoverer) + router.Use(middleware.StripSlashes) router.Use(middleware.Timeout(60 * time.Second)) - router.Group(func(r chi.Router) { - if config.Server.Jwks.Uri != "" { + if s.Config.Server.Jwks.Uri != "" { + router.Group(func(r chi.Router) { r.Use( jwtauth.Verifier(tokenAuth), jwtauth.Authenticator(tokenAuth), ) - } - r.HandleFunc("/generate", func(w http.ResponseWriter, r *http.Request) { - s.GeneratorParams.Target = r.URL.Query().Get("target") - outputs, err := generator.Generate(config, s.GeneratorParams) - if err != nil { - WriteError(w, "failed to generate config: %v", err) - return - } - // convert byte arrays to string - tmp := map[string]string{} - for path, output := range outputs { - tmp[path] = string(output) - } - - // marshal output to JSON then send - b, err := json.Marshal(tmp) - if err != nil { - WriteError(w, "failed to marshal output: %v", err) - return - } - _, err = w.Write(b) - if err != nil { - WriteError(w, "failed to write response: %v", err) - return - } + // protected routes if using auth + r.HandleFunc("/generate", s.Generate) + r.HandleFunc("/templates", s.ManageTemplates) }) - r.HandleFunc("/templates", func(w http.ResponseWriter, r *http.Request) { - // TODO: handle GET request - // TODO: handle POST request + } else { + // public routes without auth + router.HandleFunc("/generate", s.Generate) + router.HandleFunc("/templates", s.ManageTemplates) + } + + // always public routes go here (none at the moment) - }) - }) s.Handler = router return s.ListenAndServe() } + +func WriteError(w http.ResponseWriter, format string, a ...any) { + errmsg := fmt.Sprintf(format, a...) + fmt.Printf(errmsg) + w.Write([]byte(errmsg)) +} + +func (s *Server) Generate(w http.ResponseWriter, r *http.Request) { + s.GeneratorParams.Target = r.URL.Query().Get("target") + outputs, err := generator.Generate(s.Config, s.GeneratorParams) + if err != nil { + WriteError(w, "failed to generate config: %v", err) + return + } + + // convert byte arrays to string + tmp := map[string]string{} + for path, output := range outputs { + tmp[path] = string(output) + } + + // marshal output to JSON then send + b, err := json.Marshal(tmp) + if err != nil { + WriteError(w, "failed to marshal output: %v", err) + return + } + _, err = w.Write(b) + if err != nil { + WriteError(w, "failed to write response: %v", err) + return + } +} + +func (s *Server) ManageTemplates(w http.ResponseWriter, r *http.Request) { + // TODO: need to implement template managing API first in "internal/generator/templates" or something + _, err := w.Write([]byte("this is not implemented yet")) + if err != nil { + WriteError(w, "failed to write response: %v", err) + return + } +}