From e5c1b59bc18b0f1c8198241c4778e0dce78d7d44 Mon Sep 17 00:00:00 2001 From: David Allen Date: Fri, 29 Aug 2025 11:04:16 -0600 Subject: [PATCH] feat: allow expanding and remove archive after download --- cmd/download.go | 161 ++++++++++++++++++++++++++++++++---- internal/archive/archive.go | 81 +++++++++++++++++- 2 files changed, 223 insertions(+), 19 deletions(-) diff --git a/cmd/download.go b/cmd/download.go index 7d00d01..3bc502f 100644 --- a/cmd/download.go +++ b/cmd/download.go @@ -5,8 +5,10 @@ import ( "net/http" "net/url" "os" + "path/filepath" "strings" + "git.towk2.me/towk/makeshift/internal/archive" "git.towk2.me/towk/makeshift/pkg/client" "github.com/rs/zerolog/log" "github.com/spf13/cobra" @@ -37,11 +39,13 @@ var downloadCmd = cobra.Command{ }, Run: func(cmd *cobra.Command, args []string) { var ( - host, _ = cmd.Flags().GetString("host") - path, _ = cmd.Flags().GetString("path") - outputPath, _ = cmd.Flags().GetString("output") - pluginNames, _ = cmd.Flags().GetStringSlice("plugins") - profileIDs, _ = cmd.Flags().GetStringSlice("profiles") + host, _ = cmd.Flags().GetString("host") + path, _ = cmd.Flags().GetString("path") + outputPath, _ = cmd.Flags().GetString("output") + pluginNames, _ = cmd.Flags().GetStringSlice("plugins") + profileIDs, _ = cmd.Flags().GetStringSlice("profiles") + extract, _ = cmd.Flags().GetBool("extract") + removeArchive, _ = cmd.Flags().GetBool("remove-archive") c = client.New(host) res *http.Response @@ -92,15 +96,6 @@ var downloadCmd = cobra.Command{ os.Exit(1) } - // helper to write downloaded files - var writeFiles = func(path string, body []byte) { - err = os.WriteFile(outputPath, body, 0o755) - if err != nil { - log.Error().Err(err).Msg("failed to write file(s) from download") - os.Exit(1) - } - } - // determine if output path is an archive or file switch res.Header.Get("FILETYPE") { case "archive": @@ -113,6 +108,36 @@ var downloadCmd = cobra.Command{ writeFiles(outputPath, body) log.Debug().Str("path", outputPath).Msg("wrote archive to specified path") } + + // extract files if '-x' flag is passed + if extract { + var ( + dir = filepath.Dir(outputPath) + base = strings.TrimSuffix(filepath.Base(outputPath), ".tar.gz") + ) + err = archive.Expand(outputPath, fmt.Sprintf("%s/%s", dir, base)) + if err != nil { + log.Error().Err(err). + Str("path", outputPath). + Msg("failed to expand archive") + os.Exit(1) + } + } + + // optionally, remove archive if '-r' flag is passed + // NOTE: this can only be used if `-x` flag is set + if removeArchive { + if !extract { + log.Warn().Msg("requires '-x/--extract' flag to be set to 'true'") + } else { + err = os.Remove(outputPath) + if err != nil { + log.Error().Err(err). + Str("path", outputPath). + Msg("failed to remove archive") + } + } + } case "file": // write to file if '-o' specified otherwise stdout if outputPath != "" { @@ -126,18 +151,111 @@ var downloadCmd = cobra.Command{ } var downloadProfileCmd = &cobra.Command{ - Use: "profile", + Use: "profile", + Example: ` + // download a profile + makeshift download profile default +`, + Args: cobra.ExactArgs(1), Short: "Download a profile", + PreRun: func(cmd *cobra.Command, args []string) { + setenv(cmd, "host", "MAKESHIFT_HOST") + setenv(cmd, "path", "MAKESHIFT_PATH") + }, Run: func(cmd *cobra.Command, args []string) { + var ( + host, _ = cmd.Flags().GetString("host") + outputPath, _ = cmd.Flags().GetString("output") + c = client.New(host) + res *http.Response + body []byte + query string + err error + ) + for _, profileID := range args { + query = fmt.Sprintf("/profile/{%s}", profileID) + res, body, err = c.MakeRequest(client.HTTPEnvelope{ + Path: query, + Method: http.MethodGet, + }) + if err != nil { + log.Error().Err(err). + Str("host", host). + Msg("failed to make request") + os.Exit(1) + } + if res.StatusCode != http.StatusOK { + log.Error(). + Any("status", map[string]any{ + "code": res.StatusCode, + "message": res.Status, + }). + Str("host", host). + Msg("response returned bad status") + os.Exit(1) + } + if outputPath != "" { + writeFiles(outputPath, body) + } else { + fmt.Println(string(body)) + } + } }, } var downloadPluginCmd = &cobra.Command{ - Use: "plugin", + Use: "plugin", + Example: ` + // download a plugin + makeshift download plugin smd jinja2 +`, + Args: cobra.ExactArgs(1), Short: "Download a plugin", + PreRun: func(cmd *cobra.Command, args []string) { + setenv(cmd, "host", "MAKESHIFT_HOST") + setenv(cmd, "path", "MAKESHIFT_PATH") + }, Run: func(cmd *cobra.Command, args []string) { + var ( + host, _ = cmd.Flags().GetString("host") + outputPath, _ = cmd.Flags().GetString("output") + c = client.New(host) + res *http.Response + query string + body []byte + err error + ) + for _, pluginName := range args { + + query = fmt.Sprintf("/profile/%s?", pluginName) + res, body, err = c.MakeRequest(client.HTTPEnvelope{ + Path: query, + Method: http.MethodGet, + }) + if err != nil { + log.Error().Err(err). + Str("host", host). + Msg("failed to make request") + os.Exit(1) + } + if res.StatusCode != http.StatusOK { + log.Error(). + Any("status", map[string]any{ + "code": res.StatusCode, + "message": res.Status, + }). + Str("host", host). + Msg("response returned bad status") + os.Exit(1) + } + if outputPath != "" { + writeFiles(outputPath, body) + } else { + writeFiles(fmt.Sprintf("%s.so", pluginName), body) + } + } }, } @@ -150,9 +268,16 @@ func init() { downloadCmd.Flags().BoolP("extract", "x", false, "Set whether to extract archive locally after downloading") downloadCmd.Flags().BoolP("remove-archive", "r", false, "Set whether to remove the archive after extracting (used with '--extract' flag)") - downloadCmd.MarkFlagsRequiredTogether("remove-archive", "extract") - downloadCmd.AddCommand(downloadProfileCmd, downloadPluginCmd) rootCmd.AddCommand(&downloadCmd) } + +// helper to write downloaded files +func writeFiles(path string, body []byte) { + var err = os.WriteFile(path, body, 0o755) + if err != nil { + log.Error().Err(err).Msg("failed to write file(s) from download") + os.Exit(1) + } +} diff --git a/internal/archive/archive.go b/internal/archive/archive.go index 5145afe..eaf6412 100644 --- a/internal/archive/archive.go +++ b/internal/archive/archive.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "os" + "path/filepath" "strings" makeshift "git.towk2.me/towk/makeshift/pkg" @@ -32,7 +33,85 @@ func Create(filenames []string, buf io.Writer, hooks []makeshift.Hook) error { return nil } -func Expand(path string) error { +func Expand(tarname, xpath string) error { + tarfile, err := os.Open(tarname) + if err != nil { + return err + } + defer tarfile.Close() + // absPath, err := filepath.Abs(xpath) + // if err != nil { + // return err + // } + tr := tar.NewReader(tarfile) + if strings.HasSuffix(tarname, ".gz") { + gz, err := gzip.NewReader(tarfile) + if err != nil { + return fmt.Errorf("failed to create new gzip reader: %v", err) + } + defer gz.Close() + tr = tar.NewReader(gz) + } + + // untar each segment + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to get next tar header: %v", err) + } + + // determine proper file path info + var ( + fileinfo = header.FileInfo() + filename = header.Name + file *os.File + abspath string + dirpath string + ) + + // absFileName := filepath.Join(absPath, filename) // if a dir, create it, then go to next segment + if fileinfo.Mode().IsDir() { + if err := os.MkdirAll(filename, 0o755); err != nil { + return fmt.Errorf("failed to make directory '%s': %v", filename, err) + } + continue + } + + dirpath = filepath.Dir(filename) + if err = os.MkdirAll(dirpath, 0o777); err != nil { + return fmt.Errorf("failed to make directory '%s': %v", err) + } + + // create new file with original file mode + abspath, err = filepath.Abs(filename) + if err != nil { + return fmt.Errorf("failed to get absolute path: %v", err) + } + file, err = os.OpenFile( + abspath, + os.O_RDWR|os.O_CREATE|os.O_TRUNC, + fileinfo.Mode().Perm(), + ) + if err != nil { + return fmt.Errorf("failed to open file: %v", err) + } + // fmt.Printf("x %s\n", filename) + + // copy the contents to the new file + n, err := io.Copy(file, tr) + if err != nil { + return fmt.Errorf("failed to copy file: %v", err) + } + if err = file.Close(); err != nil { + return fmt.Errorf("failed to close file: %v", err) + } + if n != fileinfo.Size() { + return fmt.Errorf("wrote %d, want %d", n, fileinfo.Size()) + } + } return nil }