hoarder/sync.go

555 lines
12 KiB
Go

// Send and receive files via SFTP; downloads are split across multiple concurrent streams.
package main
import (
"bytes"
"encoding/binary"
"errors"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"io"
"log"
"os"
"path/filepath"
"strconv"
"time"
)
// NoProgress is the sentinel value meaning that no progress has been
// recorded for the current file.
const NoProgress int64 = -1
// Sync keeps track of the connection information: one SSH connection and the
// SFTP clients opened on top of it.
//
// NOTE: NewSync builds this struct with a positional literal, so the field
// order below must not change without updating NewSync.
type Sync struct {
// sshClient is the SSH connection every SFTP client is created from.
sshClient *ssh.Client
// sftpClients holds the SFTP clients; downloads use one per worker, while
// uploads and directory walks use the first one.
sftpClients []*sftp.Client
// sftpClientCount is the number of SFTP clients (the requested thread count).
sftpClientCount int
}
// NewSync creates a new Sync object: it connects to the SSH server and opens
// one SFTP client per requested thread.
//
// threads is the decimal number of SFTP clients to create and must be >= 1.
// user, pass, server and port describe the SSH endpoint and credentials.
//
// On any failure no connection is left open: the SFTP clients created so far
// and the SSH connection are closed before the error is returned.
func NewSync(threads string, user string, pass string, server string, port string) (*Sync, error) {
	// Convert the threads input to an int.
	clientCount, err := strconv.Atoi(threads)
	if err != nil {
		return nil, err
	}
	if clientCount < 1 {
		return nil, errors.New("Must have a thread count >= 1")
	}

	sshClient, err := newSSHClient(user, pass, server, port)
	if err != nil {
		return nil, err
	}

	// Initialize a total of clientCount sftp clients.
	sftpClients := make([]*sftp.Client, 0, clientCount)
	for i := 0; i < clientCount; i++ {
		sftpClient, err := sftp.NewClient(sshClient)
		if err != nil {
			// Best-effort cleanup; the creation error is the one worth
			// reporting, so close errors are deliberately ignored here.
			for _, opened := range sftpClients {
				opened.Close()
			}
			sshClient.Close()
			return nil, err
		}
		sftpClients = append(sftpClients, sftpClient)
	}

	return &Sync{
		sshClient:       sshClient,
		sftpClients:     sftpClients,
		sftpClientCount: clientCount,
	}, nil
}
// Close shuts down every SFTP client and then the SSH connection.
//
// A failure closing the SSH connection takes precedence; otherwise the error
// of the last SFTP client that failed to close (if any) is returned.
func (s Sync) Close() error {
	var lastSFTPErr error
	for idx := 0; idx < s.sftpClientCount; idx++ {
		if closeErr := s.sftpClients[idx].Close(); closeErr != nil {
			lastSFTPErr = closeErr
		}
	}

	if sshErr := s.sshClient.Close(); sshErr != nil {
		return sshErr
	}
	return lastSFTPErr
}
// newSSHClient creates a new SSH client instance using password
// authentication and confirms that sessions can be opened on it.
//
// On failure the returned client is nil and any connection that was already
// established is closed, so callers never have to clean up after an error.
func newSSHClient(user string, pass string, server string, port string) (*ssh.Client, error) {
	// NOTE(review): no HostKeyCallback is configured. Recent releases of
	// golang.org/x/crypto/ssh appear to require one for the handshake to
	// succeed - confirm against the vendored version and decide how host
	// keys should be verified.
	sshConfig := &ssh.ClientConfig{
		User: user,
		Auth: []ssh.AuthMethod{
			ssh.Password(pass),
		},
	}

	sshClient, err := ssh.Dial("tcp", server+":"+port, sshConfig)
	if err != nil {
		return nil, err
	}

	// Open a throwaway session purely to prove the server lets us create one.
	session, err := sshClient.NewSession()
	if err != nil {
		// Do not leak the connection we just dialed.
		sshClient.Close()
		return nil, err
	}
	// The probe session did no work, so its close error carries no information.
	session.Close()

	return sshClient, nil
}
// SendFiles uploads every entry of files, a map of {"source_path": "destination"},
// to the SSH server through the first SFTP client. Directories are not handled.
func (s Sync) SendFiles(files map[string]string) error {
	uploader := s.sftpClients[0]
	return SendFiles(uploader, files)
}
// SendFiles sends a list of files in the format of {"source_path": "destination"}
// to the SSH server using sftpClient. This does not handle directories.
//
// Files are sent one at a time in map iteration order; the first failure
// aborts the transfer and is returned.
func SendFiles(sftpClient *sftp.Client, files map[string]string) error {
	for sourceFile, destinationFile := range files {
		if err := sendSingleFile(sftpClient, sourceFile, destinationFile); err != nil {
			return err
		}
	}
	return nil
}

// sendSingleFile uploads the local sourceFile to destinationFile on the server,
// closing both handles on every path.
func sendSingleFile(sftpClient *sftp.Client, sourceFile string, destinationFile string) error {
	// Open file that we will be sending.
	sourceData, err := os.Open(sourceFile)
	if err != nil {
		log.Println("SendFiles: Failed to open source file " + sourceFile)
		return err
	}
	// Read-only handle: a close error cannot lose data, so it is ignored.
	defer sourceData.Close()

	// Create the destination file for the source file we're sending.
	newFile, err := sftpClient.Create(destinationFile)
	if err != nil {
		log.Println("SendFiles: Failed to create destination file " + destinationFile)
		return err
	}

	// Stream until EOF. Unlike a hand-rolled read/write loop this only ever
	// writes the bytes that were actually read, so short reads cannot pad the
	// remote file with stale buffer contents.
	if _, err := io.Copy(newFile, sourceData); err != nil {
		// The copy error is the interesting one; still release the handle.
		newFile.Close()
		return err
	}

	// Closing the remote file can surface a write failure, so report it.
	return newFile.Close()
}
// GetFile downloads sourceFile from the server into destinationFile, splitting
// the transfer across all SFTP clients (one worker goroutine per client).
//
// It returns the first worker error it observes. When every worker has
// succeeded, the progress file for destinationFile is removed.
func (s Sync) GetFile(sourceFile string, destinationFile string) error {
	// One channel per worker. Each is buffered so that a worker can always
	// deliver its result and exit, even if this method has already returned
	// because another worker failed; with unbuffered channels those workers
	// would block on their send forever.
	channels := make([]chan error, s.sftpClientCount)
	for i := range channels {
		channels[i] = make(chan error, 1)
	}

	// Start the concurrent downloads (worker numbers are 1-indexed).
	for i := 0; i < s.sftpClientCount; i++ {
		go GetFile(s.sftpClients[i], sourceFile, destinationFile, i+1, s.sftpClientCount, channels[i])
	}

	// Poll until all downloads are completed or one errors.
	pending := len(channels)
	for pending > 0 {
		for i, channel := range channels {
			if channel == nil {
				// This worker has already reported success.
				continue
			}
			select {
			case err := <-channel:
				if err != nil {
					return err
				}
				channels[i] = nil
				pending--
			default:
				// Still running.
			}
		}
		// Only wait when there is still something to wait for.
		if pending > 0 {
			time.Sleep(time.Second)
		}
	}

	return destroyProgress(destinationFile)
}
// GetFile downloads one worker's share of sourceFile (read through sftpClient)
// into destinationFile.
//
// workerNumber and workerTotal are 1-indexed, not zero-indexed: worker k of n
// copies the bytes in [size*(k-1)/n, size*k/n). The outcome (nil on success)
// is both sent on com exactly once and returned.
func GetFile(sftpClient *sftp.Client, sourceFile string, destinationFile string, workerNumber int, workerTotal int, com chan<- error) error {
	err := getFilePart(sftpClient, sourceFile, destinationFile, workerNumber, workerTotal)
	com <- err
	return err
}

// getFilePart does the actual transfer for GetFile: it resumes from the saved
// progress when that is usable, copies the worker's byte range and records
// progress before every chunk. Both file handles are closed on every path.
func getFilePart(sftpClient *sftp.Client, sourceFile string, destinationFile string, workerNumber int, workerTotal int) (err error) {
	if workerTotal < 1 || workerNumber < 1 || workerNumber > workerTotal {
		return errors.New("GetFile: worker number out of range for " + sourceFile)
	}

	// Open the source for reading.
	sourceData, err := sftpClient.OpenFile(sourceFile, os.O_RDONLY)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := sourceData.Close(); err == nil {
			err = closeErr
		}
	}()

	// Get the size of the source.
	stat, err := sourceData.Stat()
	if err != nil {
		return err
	}
	statSize := stat.Size()

	// This worker's byte range is [rangeStart, stop).
	rangeStart := (statSize * int64(workerNumber-1)) / int64(workerTotal)
	stop := statSize
	if workerNumber != workerTotal {
		stop = (statSize * int64(workerNumber)) / int64(workerTotal)
	}

	// Resume from the saved position when there is one.
	start, err := getProgress(destinationFile, workerNumber)
	if err != nil {
		return err
	}
	// Saved progress outside this worker's range is unusable: it is either
	// stale or a zero-filled slot created when another worker wrote its own
	// slot first. Restarting the range is always correct, whereas trusting the
	// value would re-download foreign ranges or ask for a negative buffer.
	if start == NoProgress || start < rangeStart || start > stop {
		start = rangeStart
	}

	// Open the destination without truncating: other workers write to
	// different regions of the same file.
	newFile, err := os.OpenFile(destinationFile, os.O_WRONLY|os.O_CREATE, 0777)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := newFile.Close(); err == nil {
			err = closeErr
		}
	}()

	// Seek both files to the computed start point.
	offset, err := sourceData.Seek(start, io.SeekStart)
	if err != nil {
		return err
	}
	if offset != start {
		return errors.New("Returned incorrect offset for source " + sourceFile)
	}
	offset, err = newFile.Seek(start, io.SeekStart)
	if err != nil {
		return err
	}
	if offset != start {
		return errors.New("Return incorrect offset for destination " + destinationFile)
	}

	// 512KB chunks, or less when the remaining range is smaller.
	bufferSize := int64(524288)
	if remaining := stop - start; remaining < bufferSize {
		bufferSize = remaining
	}
	data := make([]byte, bufferSize)

	position := start
	for position < stop {
		// Record where this chunk begins so an interrupted run can resume here.
		if err = updateProgress(destinationFile, position, workerNumber); err != nil {
			return err
		}

		// Never read past the end of this worker's range.
		chunk := data
		if remaining := stop - position; remaining < int64(len(chunk)) {
			chunk = chunk[:remaining]
		}

		read, readErr := io.ReadFull(sourceData, chunk)
		if read > 0 {
			// Write only what was actually read.
			if _, err = newFile.Write(chunk[:read]); err != nil {
				return err
			}
			position += int64(read)
		}
		if readErr == io.EOF || readErr == io.ErrUnexpectedEOF {
			// The source ended earlier than its reported size; stop here.
			break
		}
		if readErr != nil {
			return readErr
		}
	}

	return updateProgress(destinationFile, position, workerNumber)
}
// GetPath downloads the remote directory or file at sourcePath into
// destinationPath. The last element of sourcePath is recreated underneath
// destinationPath, e.g. "/srv/data" ends up as "<destinationPath>/data".
//
// It returns an error if sourcePath is empty, if the remote tree cannot be
// listed, or on the first directory creation or file download that fails.
func (s Sync) GetPath(sourcePath string, destinationPath string) error {
	if sourcePath == "" {
		return errors.New("GetPath: empty source path")
	}

	// Get all the dirs and files underneath sourcePath.
	dirs, files, err := s.getChildren(sourcePath)
	if err != nil {
		return err
	}

	// Remove the trailing slash if it exists (but keep a bare "/").
	if len(sourcePath) > 1 && sourcePath[len(sourcePath)-1] == '/' {
		sourcePath = sourcePath[:len(sourcePath)-1]
	}

	// Everything up to the parent of sourcePath is stripped from remote paths.
	sourceBase := filepath.Dir(sourcePath)
	sourceBaseLen := len(sourceBase)
	if sourceBase == "." {
		// A relative path without a directory part has no textual prefix to
		// strip; cutting len(".") bytes would chop the first character of
		// every name.
		sourceBaseLen = 0
	}
	localPath := func(remote string) string {
		return filepath.Join(destinationPath, filepath.FromSlash(remote[sourceBaseLen:]))
	}

	// Make all the directories in destinationPath.
	for _, dir := range dirs {
		if err := os.MkdirAll(localPath(dir), 0777); err != nil {
			return err
		}
	}

	// Get all the files and place them in destinationPath.
	for _, file := range files {
		if err := s.GetFile(file, localPath(file)); err != nil {
			return err
		}
	}
	return nil
}
// getChildren walks the remote tree rooted at root with the first SFTP client
// and returns the directories found, then the non-directory entries found.
// The first walk error aborts the walk and is returned with nil slices.
func (s Sync) getChildren(root string) ([]string, []string, error) {
	var (
		directories []string
		others      []string
	)

	walk := s.sftpClients[0].Walk(root)
	for walk.Step() {
		if stepErr := walk.Err(); stepErr != nil {
			return nil, nil, stepErr
		}

		info := walk.Stat()
		if info.IsDir() {
			directories = append(directories, walk.Path())
			continue
		}
		others = append(others, walk.Path())
	}

	// Step may also have stopped because of an error.
	if walkErr := walk.Err(); walkErr != nil {
		return nil, nil, walkErr
	}
	return directories, others, nil
}
// Exists determines whether a directory, file, or link exists at path on the
// server. Links are not followed (Lstat is used).
//
// It returns (false, nil) when the server reports that the path does not
// exist, and (false, err) for any other failure.
func (s Sync) Exists(path string) (bool, error) {
	_, err := s.sftpClients[0].Lstat(path)
	if err == nil {
		return true, nil
	}
	// Recognise "not found" both as the standard os error (returned by sftp
	// client versions that normalise status codes) and as the raw status
	// string this code originally matched on.
	if os.IsNotExist(err) || err.Error() == "sftp: \"No such file\" (SSH_FX_NO_SUCH_FILE)" {
		return false, nil
	}
	return false, err
}
// getProgress reads the saved download position of worker workerNumber
// (1-indexed) for the download being written to filePath.
//
// The progress file stores one big-endian int64 per worker, at offset
// 8*(workerNumber-1). NoProgress is returned when the progress file does not
// exist or does not contain a complete entry for this worker.
func getProgress(filePath string, workerNumber int) (int64, error) {
	file, err := os.Open(getProgressPath(filePath))
	if err != nil {
		if os.IsNotExist(err) {
			return NoProgress, nil
		}
		return 0, err
	}
	// Read-only handle: a close error cannot lose data, so it is ignored.
	defer file.Close()

	var progress int64
	progressData := make([]byte, binary.Size(progress))
	offset := int64(len(progressData)) * int64(workerNumber-1)

	// ReadAt reports io.EOF whenever fewer than len(progressData) bytes are
	// available at offset, which covers an empty file, a missing slot and a
	// truncated slot alike.
	if _, err := file.ReadAt(progressData, offset); err != nil {
		if err == io.EOF {
			return NoProgress, nil
		}
		return 0, err
	}

	err = binary.Read(bytes.NewReader(progressData), binary.BigEndian, &progress)
	if err != nil {
		return 0, err
	}
	return progress, nil
}
// updateProgress records written as the current download position of worker
// workerNumber (1-indexed) for the download being written to filePath.
//
// The value is stored as a big-endian int64 at offset 8*(workerNumber-1) of
// the progress file, which is created if it does not exist yet.
func updateProgress(filePath string, written int64, workerNumber int) error {
	file, err := os.OpenFile(getProgressPath(filePath), os.O_WRONLY|os.O_CREATE, 0777)
	if err != nil {
		return err
	}

	slot := make([]byte, binary.Size(written))
	binary.BigEndian.PutUint64(slot, uint64(written))
	offset := int64(len(slot)) * int64(workerNumber-1)

	if _, err := file.WriteAt(slot, offset); err != nil {
		// The write error is the interesting one; still release the handle.
		file.Close()
		return err
	}
	// This runs once per chunk, so the handle must be released every time.
	return file.Close()
}
// destroyProgress removes the progress file belonging to filePath.
// A progress file that is already absent is not treated as an error.
func destroyProgress(filePath string) error {
	if removeErr := os.Remove(getProgressPath(filePath)); removeErr != nil {
		if os.IsNotExist(removeErr) {
			return nil
		}
		return removeErr
	}
	return nil
}
// getProgressPath returns the path of the progress file for filePath: a
// sibling of filePath named ".<base name>.progress".
func getProgressPath(filePath string) string {
	parent, base := filepath.Dir(filePath), filepath.Base(filePath)
	hiddenName := "." + base + ".progress"
	return filepath.Join(parent, hiddenName)
}