-
Notifications
You must be signed in to change notification settings - Fork 29
feat(copy): use rsync by default with scp fallback #297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -23,7 +23,7 @@ import ( | |||||
| ) | ||||||
|
|
||||||
| var ( | ||||||
| copyLong = "Copy files and directories between your local machine and remote instance" | ||||||
| copyLong = "Copy files and directories between your local machine and remote instance (uses rsync by default and falls back to scp)" | ||||||
| copyExample = "brev copy instance_name:/path/to/remote/file /path/to/local/file\nbrev copy /path/to/local/file instance_name:/path/to/remote/file\nbrev copy ./local-directory/ instance_name:/remote/path/" | ||||||
| ) | ||||||
|
|
||||||
|
|
@@ -87,7 +87,7 @@ func runCopyCommand(t *terminal.Terminal, cstore CopyStore, source, dest string, | |||||
|
|
||||||
| _ = writeconnectionevent.WriteWCEOnEnv(cstore, workspace.DNS) | ||||||
|
|
||||||
| err = runSCP(t, sshName, localPath, remotePath, isUpload) | ||||||
| err = runCopyWithFallback(t, sshName, localPath, remotePath, isUpload) | ||||||
| if err != nil { | ||||||
| return breverrors.WrapAndTrace(err) | ||||||
| } | ||||||
|
|
@@ -202,33 +202,23 @@ func parseWorkspacePath(path string) (workspace, filePath string, err error) { | |||||
| return parts[0], parts[1], nil | ||||||
| } | ||||||
|
|
||||||
| func runSCP(t *terminal.Terminal, sshAlias, localPath, remotePath string, isUpload bool) error { | ||||||
| var scpCmd *exec.Cmd | ||||||
| var source, dest string | ||||||
| type commandRunner func(name string, args ...string) ([]byte, error) | ||||||
|
|
||||||
| startTime := time.Now() | ||||||
| func combinedOutputRunner(name string, args ...string) ([]byte, error) { | ||||||
| cmd := exec.Command(name, args...) //nolint:gosec | ||||||
| return cmd.CombinedOutput() | ||||||
| } | ||||||
|
|
||||||
| scpArgs := []string{"scp"} | ||||||
| func runCopyWithFallback(t *terminal.Terminal, sshAlias, localPath, remotePath string, isUpload bool) error { | ||||||
| source, dest := transferEndpoints(sshAlias, localPath, remotePath, isUpload) | ||||||
|
|
||||||
| if isUpload { | ||||||
| if isDirectory(localPath) { | ||||||
| scpArgs = append(scpArgs, "-r") | ||||||
| } | ||||||
| scpArgs = append(scpArgs, localPath, fmt.Sprintf("%s:%s", sshAlias, remotePath)) | ||||||
| source = localPath | ||||||
| dest = fmt.Sprintf("%s:%s", sshAlias, remotePath) | ||||||
| } else { | ||||||
| scpArgs = append(scpArgs, "-r") | ||||||
| scpArgs = append(scpArgs, fmt.Sprintf("%s:%s", sshAlias, remotePath), localPath) | ||||||
| source = fmt.Sprintf("%s:%s", sshAlias, remotePath) | ||||||
| dest = localPath | ||||||
| startTime := time.Now() | ||||||
| fellBack, err := transferWithFallback(sshAlias, localPath, remotePath, isUpload, combinedOutputRunner) | ||||||
| if fellBack { | ||||||
| t.Vprint(t.Yellow("rsync failed, falling back to scp...\n")) | ||||||
| } | ||||||
|
||||||
|
|
||||||
| scpCmd = exec.Command(scpArgs[0], scpArgs[1:]...) //nolint:gosec //sshAlias is validated workspace identifier | ||||||
|
|
||||||
| output, err := scpCmd.CombinedOutput() | ||||||
| if err != nil { | ||||||
| return breverrors.WrapAndTrace(fmt.Errorf("scp failed: %s\nOutput: %s", err.Error(), string(output))) | ||||||
| return breverrors.WrapAndTrace(err) | ||||||
| } | ||||||
|
|
||||||
| duration := time.Since(startTime) | ||||||
|
|
@@ -238,6 +228,70 @@ func runSCP(t *terminal.Terminal, sshAlias, localPath, remotePath string, isUplo | |||||
| return nil | ||||||
| } | ||||||
|
|
||||||
| func transferWithFallback(sshAlias, localPath, remotePath string, isUpload bool, runner commandRunner) (bool, error) { | ||||||
| err := runRsyncCommand(sshAlias, localPath, remotePath, isUpload, runner) | ||||||
| if err == nil { | ||||||
| return false, nil | ||||||
| } | ||||||
|
|
||||||
| scpErr := runSCPCommand(sshAlias, localPath, remotePath, isUpload, runner) | ||||||
| if scpErr != nil { | ||||||
| return true, fmt.Errorf("rsync failed: %v\nscp fallback failed: %w", err, scpErr) | ||||||
|
||||||
| return true, fmt.Errorf("rsync failed: %v\nscp fallback failed: %w", err, scpErr) | |
| return true, fmt.Errorf("%v\nscp fallback failed: %w", err, scpErr) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| package copy | ||
|
|
||
| import ( | ||
| "errors" | ||
| "os" | ||
| "path/filepath" | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| ) | ||
|
|
||
| func TestBuildRsyncArgs(t *testing.T) { | ||
| t.Run("upload file", func(t *testing.T) { | ||
| args := buildRsyncArgs("ws", "/tmp/local.txt", "/remote/path", true) | ||
| assert.Equal(t, []string{"-z", "-e", "ssh", "/tmp/local.txt", "ws:/remote/path"}, args) | ||
| }) | ||
|
|
||
| t.Run("upload directory", func(t *testing.T) { | ||
| tmpDir := t.TempDir() | ||
| localDir := filepath.Join(tmpDir, "mydir") | ||
| err := os.MkdirAll(localDir, 0o755) | ||
| assert.NoError(t, err) | ||
|
|
||
| args := buildRsyncArgs("ws", localDir, "/remote/path", true) | ||
| assert.Equal(t, []string{"-z", "-e", "ssh", "-r", localDir, "ws:/remote/path"}, args) | ||
| }) | ||
|
|
||
| t.Run("download path", func(t *testing.T) { | ||
| args := buildRsyncArgs("ws", "/tmp/local.txt", "/remote/path", false) | ||
| assert.Equal(t, []string{"-z", "-e", "ssh", "-r", "ws:/remote/path", "/tmp/local.txt"}, args) | ||
| }) | ||
| } | ||
|
|
||
| func TestBuildSCPArgs(t *testing.T) { | ||
| t.Run("upload file", func(t *testing.T) { | ||
| args := buildSCPArgs("ws", "/tmp/local.txt", "/remote/path", true) | ||
| assert.Equal(t, []string{"/tmp/local.txt", "ws:/remote/path"}, args) | ||
| }) | ||
|
|
||
| t.Run("upload directory", func(t *testing.T) { | ||
| tmpDir := t.TempDir() | ||
| localDir := filepath.Join(tmpDir, "mydir") | ||
| err := os.MkdirAll(localDir, 0o755) | ||
| assert.NoError(t, err) | ||
|
|
||
| args := buildSCPArgs("ws", localDir, "/remote/path", true) | ||
| assert.Equal(t, []string{"-r", localDir, "ws:/remote/path"}, args) | ||
| }) | ||
|
|
||
| t.Run("download path", func(t *testing.T) { | ||
| args := buildSCPArgs("ws", "/tmp/local.txt", "/remote/path", false) | ||
| assert.Equal(t, []string{"-r", "ws:/remote/path", "/tmp/local.txt"}, args) | ||
| }) | ||
| } | ||
|
|
||
| func TestTransferWithFallback(t *testing.T) { | ||
| t.Run("rsync success", func(t *testing.T) { | ||
| calls := []string{} | ||
| runner := func(name string, args ...string) ([]byte, error) { | ||
| calls = append(calls, name) | ||
| return []byte("ok"), nil | ||
| } | ||
|
|
||
| fellBack, err := transferWithFallback("ws", "/tmp/local.txt", "/remote/path", true, runner) | ||
| assert.NoError(t, err) | ||
| assert.False(t, fellBack) | ||
| assert.Equal(t, []string{"rsync"}, calls) | ||
| }) | ||
|
|
||
| t.Run("rsync fails and scp succeeds", func(t *testing.T) { | ||
| calls := []string{} | ||
| runner := func(name string, args ...string) ([]byte, error) { | ||
| calls = append(calls, name) | ||
| if name == "rsync" { | ||
| return []byte("rsync failed"), errors.New("exit status 1") | ||
| } | ||
| return []byte("scp ok"), nil | ||
| } | ||
|
|
||
| fellBack, err := transferWithFallback("ws", "/tmp/local.txt", "/remote/path", true, runner) | ||
| assert.NoError(t, err) | ||
| assert.True(t, fellBack) | ||
| assert.Equal(t, []string{"rsync", "scp"}, calls) | ||
| }) | ||
|
|
||
| t.Run("rsync fails and scp fails", func(t *testing.T) { | ||
| runner := func(name string, args ...string) ([]byte, error) { | ||
| if name == "rsync" { | ||
| return []byte("rsync output"), errors.New("exit status 1") | ||
| } | ||
| return []byte("scp output"), errors.New("exit status 1") | ||
| } | ||
|
|
||
| fellBack, err := transferWithFallback("ws", "/tmp/local.txt", "/remote/path", true, runner) | ||
| assert.Error(t, err) | ||
| assert.True(t, fellBack) | ||
| assert.Contains(t, err.Error(), "rsync failed") | ||
| assert.Contains(t, err.Error(), "scp fallback failed") | ||
| assert.Contains(t, err.Error(), "rsync output") | ||
| assert.Contains(t, err.Error(), "scp output") | ||
| }) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
//nolint:gosechere suppresses a security linter finding without any justification. Elsewhere in the codebase,nolint:gosecis consistently annotated with a short reason (e.g.,agentskill.gouses//nolint:gosec // skill files are not sensitive). Add a brief rationale here (and, if applicable, note what input validation makes this safe) to avoid masking real issues in future changes.