536 lines
12 KiB
Go
536 lines
12 KiB
Go
// Send and receive files over SFTP; downloads are split across multiple concurrent SFTP streams.
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"github.com/pkg/sftp"
|
|
"golang.org/x/crypto/ssh"
|
|
"io"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"time"
|
|
)
|
|
|
|
// NoProgress is the sentinel getProgress returns when no usable checkpoint
// exists for a worker; GetFile then starts that worker at its default offset.
const NoProgress int64 = -1
|
|
|
|
// Sync Keeps track of the connection information.
|
|
type Sync struct {
|
|
sshClient *ssh.Client
|
|
sftpClients []*sftp.Client
|
|
sftpClientCount int
|
|
}
|
|
|
|
// NewSync Create a new Sync object, connect to the SSH server, and create sftp clients
|
|
func NewSync(threads string, user string, pass string, server string, port string) (*Sync, error) {
|
|
// convert the threads input to an int
|
|
clientCount, err := strconv.Atoi(threads)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if clientCount < 1 {
|
|
return nil, errors.New("Must have a thread count >= 1")
|
|
}
|
|
|
|
sshClient, err := newSSHClient(user, pass, server, port)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// initialize a total of client_count sftp clients
|
|
sftpClients := make([]*sftp.Client, clientCount)
|
|
for i := 0; i < clientCount; i++ {
|
|
sftpClient, err := sftp.NewClient(sshClient)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sftpClients[i] = sftpClient
|
|
}
|
|
|
|
return &Sync{sshClient, sftpClients, clientCount}, nil
|
|
}
|
|
|
|
// Create a new SSH client instance and confirm that we can make sessions
|
|
func newSSHClient(user string, pass string, server string, port string) (*ssh.Client, error) {
|
|
sshConfig := &ssh.ClientConfig{
|
|
User: user,
|
|
Auth: []ssh.AuthMethod{
|
|
ssh.Password(pass),
|
|
},
|
|
}
|
|
|
|
sshClient, err := ssh.Dial("tcp", server + ":" + port, sshConfig)
|
|
|
|
if err != nil {
|
|
return sshClient, err
|
|
}
|
|
|
|
session, err := sshClient.NewSession()
|
|
if err != nil {
|
|
return sshClient, err
|
|
}
|
|
|
|
defer session.Close()
|
|
|
|
return sshClient, err
|
|
}
|
|
|
|
// SendFiles Send a list of files in the format of {"source_path": "destination"} to the SSH server. This does not handle directories.
|
|
func (s Sync) SendFiles(files map[string]string) error {
|
|
return SendFiles(s.sftpClients[0], files)
|
|
}
|
|
|
|
// SendFiles Send a list of files in the format of {"source_path": "destination"} to the SSH server. This does not handle directories.
|
|
func SendFiles(sftpClient *sftp.Client, files map[string]string) error {
|
|
for sourceFile, destinationFile := range files {
|
|
// 512KB buffer for reading/sending data
|
|
data := make([]byte, 524288)
|
|
|
|
// Open file that we will be sending
|
|
sourceData, err := os.Open(sourceFile)
|
|
if err != nil {
|
|
log.Println("SendFiles: Failed to open source file " + sourceFile)
|
|
return err
|
|
}
|
|
|
|
// Get the info of the file that we will be sending
|
|
sourceStat, err := sourceData.Stat()
|
|
if err != nil {
|
|
log.Println("SendFiles: Failed to stat source file " + sourceFile)
|
|
return err
|
|
}
|
|
|
|
// Extract the size of the file that we will be sending
|
|
sourceSize := sourceStat.Size()
|
|
// Create the destination file for the source file we're sending
|
|
newFile, err := sftpClient.Create(destinationFile)
|
|
if err != nil {
|
|
log.Println("SendFiles: Failed to create destination file " + destinationFile)
|
|
return err
|
|
}
|
|
|
|
// Track our position in reading/writing the file
|
|
var currentPosition int64
|
|
for currentPosition < sourceSize {
|
|
// If the next iteration will be greater than the file size, reduce to the data size
|
|
if currentPosition + int64(len(data)) > sourceSize {
|
|
data = make([]byte, sourceSize - currentPosition)
|
|
}
|
|
|
|
// Read data from the source file
|
|
read, err := sourceData.Read(data)
|
|
if err != nil {
|
|
// If it's the end of the file and we didn't read anything, break
|
|
if err == io.EOF {
|
|
if read == 0 {
|
|
break
|
|
}
|
|
} else {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Write the data from the source file to the destination file
|
|
_, err = newFile.Write(data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Update the current position in the file
|
|
currentPosition += int64(read)
|
|
}
|
|
|
|
// close the source file
|
|
err = sourceData.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// close the destination file
|
|
err = newFile.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetFile Get a file from the source_file path to be stored in destination_file path
|
|
func (s Sync) GetFile(sourceFile string, destinationFile string) error {
|
|
// Store channels for all the concurrent download parts
|
|
channels := make([]chan error, s.sftpClientCount)
|
|
|
|
// Make channels for all the concurrent downloads
|
|
for i := 0; i < s.sftpClientCount; i++ {
|
|
channels[i] = make(chan error)
|
|
}
|
|
|
|
// Start the concurrent downloads
|
|
for i := 0; i < s.sftpClientCount; i++ {
|
|
go GetFile(s.sftpClients[i], sourceFile, destinationFile, i + 1, s.sftpClientCount, channels[i])
|
|
}
|
|
|
|
// Block until all downloads are completed or one errors
|
|
allDone := false
|
|
for !allDone {
|
|
allDone = true
|
|
for i, channel := range channels {
|
|
if channel == nil {
|
|
continue
|
|
}
|
|
|
|
select {
|
|
case err := <- channel:
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
channels[i] = nil
|
|
break
|
|
default:
|
|
// still running
|
|
if allDone {
|
|
allDone = false
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
time.Sleep(time.Second)
|
|
}
|
|
|
|
err := destroyProgress(destinationFile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetFile Get a file from the source_file path to be stored in destination_file path.
|
|
// worker_number and work_total are not zero indexed, but 1 indexed
|
|
func GetFile(sftpClient *sftp.Client, sourceFile string, destinationFile string, workerNumber int, workerTotal int, com chan <- error) error {
|
|
// Open source_data for reading
|
|
sourceData, err := sftpClient.OpenFile(sourceFile, os.O_RDONLY)
|
|
if err != nil {
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
// Get info for source_data
|
|
stat, err := sourceData.Stat()
|
|
if err != nil {
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
// Extract the size of source_data
|
|
statSize := stat.Size()
|
|
|
|
// Calculate which byte to start reading data from
|
|
start, err := getProgress(destinationFile, workerNumber)
|
|
if err != nil {
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
if start == NoProgress {
|
|
if workerNumber == 1 {
|
|
start = 0
|
|
} else {
|
|
start = (statSize * int64(workerNumber - 1)) / int64(workerTotal)
|
|
}
|
|
}
|
|
|
|
// Calculate which byte to stop reading data from
|
|
var stop int64
|
|
if workerNumber == workerTotal {
|
|
stop = statSize
|
|
} else {
|
|
stop = (statSize * int64(workerNumber)) / int64(workerTotal)
|
|
}
|
|
|
|
// Create the new file for writing
|
|
newFile, err := os.OpenFile(destinationFile, os.O_WRONLY | os.O_CREATE, 0777)
|
|
if err != nil {
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
// Seek to the computed start point
|
|
offset, err := sourceData.Seek(start, 0)
|
|
if err != nil {
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
// Seeking messed up real bad
|
|
if offset != start {
|
|
err = errors.New("Returned incorrect offset for source " + sourceFile)
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
// Seek to the computed start point
|
|
offset, err = newFile.Seek(start, 0)
|
|
if err != nil {
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
// Seeking messed up real bad
|
|
if offset != start {
|
|
err = errors.New("Return incorrect offset for destination " + destinationFile)
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
// 512KB chunks
|
|
var dataSize int64 = 524288
|
|
// Change the size if the chunk is larger than the file
|
|
chunkDifference := stop - start
|
|
if chunkDifference < dataSize {
|
|
dataSize = chunkDifference
|
|
}
|
|
|
|
// Initialize the buffer for reading/writing
|
|
data := make([]byte, dataSize)
|
|
var currentSize int64
|
|
for currentSize = start; currentSize < stop; currentSize += dataSize {
|
|
err = updateProgress(destinationFile, currentSize, workerNumber)
|
|
if err != nil {
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
// Adjust the size of the buffer if the next iteration will be greater than what has yet to be read
|
|
if currentSize + dataSize > stop {
|
|
dataSize = stop - currentSize
|
|
data = make([]byte, dataSize)
|
|
}
|
|
|
|
// Read the chunk
|
|
read, err := sourceData.Read(data)
|
|
if err != nil {
|
|
// Exit the loop if we're at the end of the file and no data was read
|
|
if err == io.EOF {
|
|
if read == 0 {
|
|
break
|
|
}
|
|
} else {
|
|
com <- err
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Write the chunk
|
|
_, err = newFile.Write(data)
|
|
if err != nil {
|
|
com <- err
|
|
return err
|
|
}
|
|
}
|
|
|
|
err = updateProgress(destinationFile, currentSize, workerNumber)
|
|
if err != nil {
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
// Close out the files
|
|
err = sourceData.Close()
|
|
if err != nil {
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
err = newFile.Close()
|
|
if err != nil {
|
|
com <- err
|
|
return err
|
|
}
|
|
|
|
com <- nil
|
|
return nil
|
|
}
|
|
|
|
// GetPath Get a given directory or file defined by source_path and save it to destination_path
|
|
func (s Sync) GetPath(sourcePath string, destinationPath string) error {
|
|
// Get all the dirs and files underneath source_path
|
|
dirs, files, err := s.getChildren(sourcePath)
|
|
|
|
// Remove the trailing slash if it exists
|
|
if sourcePath[len(sourcePath) - 1] == '/' {
|
|
sourcePath = sourcePath[:len(sourcePath) - 1]
|
|
}
|
|
|
|
// Get the parent path of source_path
|
|
sourceBase := filepath.Dir(sourcePath)
|
|
sourceBaseLen := len(sourceBase)
|
|
|
|
// Make all the directories in destination_path
|
|
for _, dir := range dirs {
|
|
dir = filepath.Join(destinationPath, filepath.FromSlash(dir[sourceBaseLen:]))
|
|
err = os.MkdirAll(dir, 0777)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Get all the files and place them in destination_path
|
|
for _, file := range files {
|
|
newFile := filepath.Join(destinationPath, filepath.FromSlash(file[sourceBaseLen:]))
|
|
err = s.GetFile(file, newFile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Get the directories and files underneath a given sftp root path
|
|
func (s Sync) getChildren(root string) ([]string, []string, error) {
|
|
// Used to walk through the path
|
|
walker := s.sftpClients[0].Walk(root)
|
|
|
|
// Keep track of the directories
|
|
var dirs []string
|
|
// Keep track of the files
|
|
var files []string
|
|
|
|
// Walk through the files and directories
|
|
for walker.Step() {
|
|
err := walker.Err()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
stat := walker.Stat()
|
|
if stat.IsDir() {
|
|
dirs = append(dirs, walker.Path())
|
|
} else {
|
|
files = append(files, walker.Path())
|
|
}
|
|
}
|
|
|
|
err := walker.Err()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return dirs, files, nil
|
|
}
|
|
|
|
// Exists Determine if a directory, file, or link exists
|
|
func (s Sync) Exists(path string) (bool, error) {
|
|
_, err := s.sftpClients[0].Lstat(path)
|
|
|
|
if err != nil {
|
|
if err.Error() == "sftp: \"No such file\" (SSH_FX_NO_SUCH_FILE)" {
|
|
return false, nil
|
|
}
|
|
|
|
return false, err
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func getProgress(filePath string, workerNumber int) (int64, error) {
|
|
file, err := os.Open(getProgressPath(filePath))
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return NoProgress, nil
|
|
}
|
|
|
|
return 0, err
|
|
}
|
|
|
|
fileStat, err := file.Stat()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if fileStat.Size() == 0 {
|
|
return NoProgress, nil
|
|
}
|
|
|
|
var progress int64
|
|
progressSize := int64(binary.Size(progress))
|
|
offset := progressSize * int64(workerNumber - 1)
|
|
|
|
realOffset, err := file.Seek(offset, os.SEEK_SET)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if realOffset != offset {
|
|
return 0, errors.New("getProgress: Tried to seek to " + string(offset) + " but got " + string(realOffset) + " instead")
|
|
}
|
|
|
|
progressData := make([]byte, progressSize)
|
|
|
|
read, err := file.Read(progressData)
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
return NoProgress, nil
|
|
}
|
|
|
|
return 0, err
|
|
}
|
|
|
|
if int64(read) != progressSize {
|
|
return NoProgress, nil
|
|
}
|
|
|
|
err = binary.Read(bytes.NewReader(progressData), binary.BigEndian, &progress)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return progress, nil
|
|
}
|
|
|
|
func updateProgress(filePath string, written int64, workerNumber int) error {
|
|
file, err := os.OpenFile(getProgressPath(filePath), os.O_WRONLY | os.O_CREATE, 0777)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
writtenSize := int64(binary.Size(written))
|
|
offset := writtenSize * int64(workerNumber - 1)
|
|
|
|
realOffset, err := file.Seek(offset, os.SEEK_SET)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if realOffset != offset {
|
|
return errors.New("updateProgress: Tried to seek to " + string(offset) + " but got " + string(realOffset) + " instead")
|
|
}
|
|
|
|
return binary.Write(file, binary.BigEndian, written)
|
|
}
|
|
|
|
func destroyProgress(filePath string) error {
|
|
err := os.Remove(getProgressPath(filePath))
|
|
if err != nil && !os.IsNotExist(err) {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getProgressPath returns the path of the hidden progress file for filePath:
// ".<base>.progress", placed in the same directory as filePath.
func getProgressPath(filePath string) string {
	hiddenName := "." + filepath.Base(filePath) + ".progress"
	return filepath.Join(filepath.Dir(filePath), hiddenName)
}
|