From d36b7ff8d67f8e0c4db538661693e159e0cd9d11 Mon Sep 17 00:00:00 2001 From: Tony Blyler Date: Mon, 29 Dec 2014 18:58:23 -0500 Subject: [PATCH] Cleaned up code with gofmt/lint. Added download resuming. Made downloads for subdirectories work correctly. --- hoarder.go | 168 ++++++++++++++++----------- rtorrent.go | 44 +++---- sync.go | 321 +++++++++++++++++++++++++++++++++++++--------------- 3 files changed, 353 insertions(+), 180 deletions(-) diff --git a/hoarder.go b/hoarder.go index 80e35a1..64b832a 100644 --- a/hoarder.go +++ b/hoarder.go @@ -14,16 +14,16 @@ import ( ) // Load information from a given config file config_path -func loadConfig(config_path string) (map[string]string, error) { - file, err := os.Open(config_path) +func loadConfig(configPath string) (map[string]string, error) { + file, err := os.Open(configPath) if err != nil { - log.Println("Failed to open configuration file " + config_path) + log.Println("Failed to open configuration file " + configPath) return nil, err } data, err := ioutil.ReadAll(file) if err != nil { - log.Println("Failed to read configuration file " + config_path) + log.Println("Failed to read configuration file " + configPath) return nil, err } @@ -38,27 +38,30 @@ func loadConfig(config_path string) (map[string]string, error) { } // Ignore malformed lines - sep_position := strings.Index(line, ": \"") - if sep_position == -1 { + sepPosition := strings.Index(line, ": \"") + if sepPosition == -1 { continue } - config[line[:sep_position]] = line[sep_position + 3:len(line) - 1] + config[line[:sepPosition]] = line[sepPosition + 3:len(line) - 1] } return config, nil } // Checker routine to see if torrents are completed -func checker(config map[string]string, checker_chan <- chan map[string]string, com chan <- error) error { +func checker(config map[string]string, checkerChan <- chan map[string]string, com chan <- error) error { for { - torrent_info := <-checker_chan + torrentInfo := <-checkerChan - log.Println("Started checking " + torrent_info["torrent_path"]) + log.Println("Started checking " + torrentInfo["torrent_path"]) - torrent, err := NewTorrent(config["xml_user"], config["xml_pass"], config["xml_address"], torrent_info["torrent_path"]) + torrent, err := NewTorrent(config["xml_user"], config["xml_pass"], config["xml_address"], torrentInfo["torrent_path"]) if err != nil { - log.Println("Failed to initialize torrent for " + torrent_info["torrent_path"] + ": " + err.Error()) + if !os.IsNotExist(err) { + log.Println("Failed to initialize torrent for " + torrentInfo["torrent_path"] + ": " + err.Error()) + } + continue } @@ -69,26 +72,6 @@ func checker(config map[string]string, checker_chan <- chan map[string]string, c return err } - destination_torrent := filepath.Join(config["remote_torrent_dir"], filepath.Base(torrent.path)) - exists, err := syncer.Exists(destination_torrent) - if err != nil { - log.Println("Failed to see if " + torrent_info["torrent_path"] + " already exists on the server: " + err.Error()) - com <- err - return err - } - - if !exists { - err = syncer.SendFiles(map[string]string{torrent.path: destination_torrent}) - if err != nil { - log.Println("Failed to send " + torrent.path + " to the server: " + err.Error()) - com <- err - return err - } - - // continue because rtorrent more than likely will not finish the torrent by the next call - continue - } - completed, err := torrent.GetTorrentComplete() if err != nil { log.Println("Failed to see if " + torrent.path + " is completed: " + err.Error()) @@ -105,47 +88,47 @@ func checker(config map[string]string, checker_chan <- chan map[string]string, c if completed { log.Println(name + " is completed, starting download now") - remote_download_path := filepath.Join(config["remote_download_dir"], name) - exists, err := syncer.Exists(remote_download_path) + remoteDownloadPath := filepath.Join(config["remote_download_dir"], name) + exists, err := syncer.Exists(remoteDownloadPath) if err != nil { - log.Println("Failed to see if " + remote_download_path + " exists: " + err.Error()) + log.Println("Failed to see if " + remoteDownloadPath + " exists: " + err.Error()) com <- err return err } // file/dir to downlaod does not exist! if !exists { - err = errors.New(remote_download_path + " does not exist on remote server") + err = errors.New(remoteDownloadPath + " does not exist on remote server") com <- err return err } - completed_destination := filepath.Join(config["local_download_dir"], name) + completedDestination := filepath.Join(torrentInfo["local_download_dir"], name) - _, err = os.Stat(completed_destination) + _, err = os.Stat(completedDestination) if err == nil { - err = errors.New(completed_destination + " already exists, not downloading") + err = errors.New(completedDestination + " already exists, not downloading") continue } else if !os.IsNotExist(err) { - log.Println("Failed to stat: " + completed_destination + ": " + err.Error()) + log.Println("Failed to stat: " + completedDestination + ": " + err.Error()) com <- err return err } - err = syncer.GetPath(remote_download_path, config["temp_download_dir"]) + err = syncer.GetPath(remoteDownloadPath, config["temp_download_dir"]) if err != nil { - log.Println("Failed to download " + remote_download_path + ": " + err.Error()) + log.Println("Failed to download " + remoteDownloadPath + ": " + err.Error()) com <- err return err } log.Println("Successfully downloaded " + name) - temp_destination := filepath.Join(config["temp_download_dir"], name) + tempDestination := filepath.Join(config["temp_download_dir"], name) - err = os.Rename(temp_destination, completed_destination) + err = os.Rename(tempDestination, completedDestination) if err != nil { - log.Println("Failed to move " + temp_destination + " to " + completed_destination + ": " + err.Error()) + log.Println("Failed to move " + tempDestination + " to " + completedDestination + ": " + err.Error()) com <- err return err } @@ -166,38 +149,85 @@ func checker(config map[string]string, checker_chan <- chan map[string]string, c } // Scanner routine to see if there are new torrent_files -func scanner(config map[string]string, checker_chan chan <- map[string]string, com chan <- error) error { - watch_dirs := map[string]string{config["local_torrent_dir"]: config["local_download_dir"]} - dir_contents, err := ioutil.ReadDir(config["local_torrent_dir"]) +func scanner(config map[string]string, checkerChan chan <- map[string]string, com chan <- error) error { + watchDirs := map[string]string{config["local_torrent_dir"]: config["local_download_dir"]} + dirContents, err := ioutil.ReadDir(config["local_torrent_dir"]) if err != nil { com <- err return err } - for _, file := range dir_contents { + for _, file := range dirContents { if file.IsDir() { - watch_dirs[filepath.Join(config["local_torrent_dir"], file.Name())] = filepath.Join(config["local_download_dir"], file.Name()) + watchDirs[filepath.Join(config["local_torrent_dir"], file.Name())] = filepath.Join(config["local_download_dir"], file.Name()) } } + uploaded := make(map[string]bool) + downloadingTorrentPath := "" for { - for watch_dir, download_dir := range watch_dirs { - torrent_files, err := ioutil.ReadDir(watch_dir) + for watchDir, downloadDir := range watchDirs { + torrentFiles, err := ioutil.ReadDir(watchDir) if err != nil { com <- err return err } - for _, torrent_file := range torrent_files { - if torrent_file.IsDir() { + for _, torrentFile := range torrentFiles { + if torrentFile.IsDir() { // skip because we don't do more than one level of watching continue } - checker_chan <- map[string]string{ - "torrent_path": filepath.Join(watch_dir, torrent_file.Name()), - "local_download_dir": download_dir, + torrentPath := filepath.Join(watchDir, torrentFile.Name()) + + if !uploaded[torrentPath] { + syncer, err := NewSync("1", config["ssh_user"], config["ssh_pass"], config["ssh_server"], config["ssh_port"]) + if err != nil { + log.Println("Failed to create a new sync: " + err.Error()) + continue + } + + destinationTorrent := filepath.Join(config["remote_torrent_dir"], filepath.Base(torrentPath)) + exists, err := syncer.Exists(destinationTorrent) + if err != nil { + log.Println("Failed to see if " + torrentPath + " already exists on the server: " + err.Error()) + continue + } + + if exists { + uploaded[torrentPath] = true + } else { + err = syncer.SendFiles(map[string]string{torrentPath: destinationTorrent}) + if err == nil { + log.Println("Successfully uploaded " + torrentPath + " to " + destinationTorrent) + uploaded[torrentPath] = true + } else { + log.Println("Failed to upload " + torrentPath + " to " + destinationTorrent + ": " + err.Error()) + } + + continue + } + } + + downloadInfo := map[string]string{ + "torrent_path": torrentPath, + "local_download_dir": downloadDir, + } + + // try to send the info to the checker goroutine (nonblocking) + select { + case checkerChan <- downloadInfo: + // don't keep track of completed downloads in the uploaded map + if downloadingTorrentPath != "" { + delete(uploaded, downloadingTorrentPath) + } + + downloadingTorrentPath = torrentPath + break + default: + break } } } @@ -215,18 +245,18 @@ func main() { os.Exit(1) }) - var config_path string - flag.StringVar(&config_path, "config", "", "Location of the config file") + var configPath string + flag.StringVar(&configPath, "config", "", "Location of the config file") flag.Parse() - if config_path == "" { + if configPath == "" { log.Println("Missing argument for configuration file path") flag.PrintDefaults() os.Exit(1) } log.Println("Reading configuration file") - config, err := loadConfig(config_path) + config, err := loadConfig(configPath) if err != nil { log.Println(err) os.Exit(1) @@ -234,7 +264,7 @@ func main() { log.Println("Successfully read configuration file") - checker_chan := make(chan map[string]string, 50) + checkerChan := make(chan map[string]string, 50) if err != nil { log.Println(err) @@ -242,25 +272,27 @@ func main() { } log.Println("Starting the scanner routine") - scanner_com := make(chan error) - go scanner(config, checker_chan, scanner_com) + scannerCom := make(chan error) + go scanner(config, checkerChan, scannerCom) log.Println("Starting the checker routine") - checker_com := make(chan error) - go checker(config, checker_chan, checker_com) + checkerCom := make(chan error) + go checker(config, checkerChan, checkerCom) for { select { - case err := <-scanner_com: + case err := <-scannerCom: if err != nil { log.Println("Scanner failed: " + err.Error()) os.Exit(1) } - case err := <-checker_com: + case err := <-checkerCom: if err != nil { log.Println("Checker failed: " + err.Error()) os.Exit(1) } + default: + break } time.Sleep(time.Second * 5) diff --git a/rtorrent.go b/rtorrent.go index 40dcf83..4ec34b0 100644 --- a/rtorrent.go +++ b/rtorrent.go @@ -15,28 +15,28 @@ import ( bencode "github.com/jackpal/bencode-go" ) -// Keeps track of a torrent file's and rtorrent's XMLRPC information. +// Torrent Keeps track of a torrent file's and rtorrent's XMLRPC information. type Torrent struct { path string hash string - xml_user string - xml_pass string - xml_address string + xmlUser string + xmlPass string + xmlAddress string } -// Create a new Torrent instance while computing its hash. -func NewTorrent(xml_user string, xml_pass string, xml_address string, file_path string) (*Torrent, error) { - hash, err := getTorrentHash(file_path) +// NewTorrent Create a new Torrent instance while computing its hash. +func NewTorrent(xmlUser string, xmlPass string, xmlAddress string, filePath string) (*Torrent, error) { + hash, err := getTorrentHash(filePath) if err != nil { return nil, err } - return &Torrent{file_path, hash, xml_user, xml_pass, xml_address}, nil + return &Torrent{filePath, hash, xmlUser, xmlPass, xmlAddress}, nil } // Compute the torrent hash for a given torrent file path returning an all caps sha1 hash as a string. -func getTorrentHash(file_path string) (string, error) { - file, err := os.Open(file_path) +func getTorrentHash(filePath string) (string, error) { + file, err := os.Open(filePath) if err != nil { return "", err } @@ -56,18 +56,18 @@ func getTorrentHash(file_path string) (string, error) { var encoded bytes.Buffer bencode.Marshal(&encoded, decoded["info"]) - encoded_string := encoded.String() + encodedString := encoded.String() hash := sha1.New() - io.WriteString(hash, encoded_string) + io.WriteString(hash, encodedString) - hash_string := strings.ToUpper(hex.EncodeToString(hash.Sum(nil))) + hashString := strings.ToUpper(hex.EncodeToString(hash.Sum(nil))) - return hash_string, nil + return hashString, nil } // Send a command and its argument to the rtorrent XMLRPC and get the response. -func (t Torrent) xmlRpcSend (command string, arg string) (string, error) { +func (t Torrent) xmlRPCSend (command string, arg string) (string, error) { // This is hacky XML to send to the server buf := []byte("\n" + "\n" + @@ -81,14 +81,14 @@ func (t Torrent) xmlRpcSend (command string, arg string) (string, error) { buffer := bytes.NewBuffer(buf) - request, err := http.NewRequest("POST", t.xml_address, buffer) + request, err := http.NewRequest("POST", t.xmlAddress, buffer) if err != nil { return "", err } // Set the basic HTTP auth if we have a user or password - if t.xml_user != "" || t.xml_pass != "" { - request.SetBasicAuth(t.xml_user, t.xml_pass) + if t.xmlUser != "" || t.xmlPass != "" { + request.SetBasicAuth(t.xmlUser, t.xmlPass) } client := &http.Client{} @@ -118,14 +118,14 @@ func (t Torrent) xmlRpcSend (command string, arg string) (string, error) { return values[0][1], nil } -// Get the torrent's name from rtorrent. +// GetTorrentName Get the torrent's name from rtorrent. func (t Torrent) GetTorrentName() (string, error) { - return t.xmlRpcSend("d.get_name", t.hash) + return t.xmlRPCSend("d.get_name", t.hash) } -// Get the completion status of the torrent from rtorrent. +// GetTorrentComplete Get the completion status of the torrent from rtorrent. func (t Torrent) GetTorrentComplete() (bool, error) { - complete, err := t.xmlRpcSend("d.get_complete", t.hash) + complete, err := t.xmlRPCSend("d.get_complete", t.hash) if err != nil { return false, err } diff --git a/sync.go b/sync.go index 1600884..869bec3 100644 --- a/sync.go +++ b/sync.go @@ -2,6 +2,8 @@ package main import ( + "bytes" + "encoding/binary" "errors" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" @@ -10,117 +12,121 @@ import ( "os" "path/filepath" "strconv" + "time" ) -// Keeps track of the connection information. +const ( + // NoProgress no progress for current file + NoProgress int64 = -1 +) + +// Sync Keeps track of the connection information. type Sync struct { sshClient *ssh.Client sftpClients []*sftp.Client sftpClientCount int } -// Create a new Sync object, connect to the SSH server, and create sftp clients +// NewSync Create a new Sync object, connect to the SSH server, and create sftp clients func NewSync(threads string, user string, pass string, server string, port string) (*Sync, error) { // convert the threads input to an int - client_count, err := strconv.Atoi(threads) + clientCount, err := strconv.Atoi(threads) if err != nil { return nil, err } - if client_count < 1 { + if clientCount < 1 { return nil, errors.New("Must have a thread count >= 1") } - ssh_client, err := newSSHClient(user, pass, server, port) + sshClient, err := newSSHClient(user, pass, server, port) if err != nil { return nil, err } // initialize a total of client_count sftp clients - sftp_clients := make([]*sftp.Client, client_count) - for i := 0; i < client_count; i++ { - sftp_client, err := sftp.NewClient(ssh_client) + sftpClients := make([]*sftp.Client, clientCount) + for i := 0; i < clientCount; i++ { + sftpClient, err := sftp.NewClient(sshClient) if err != nil { return nil, err } - sftp_clients[i] = sftp_client + sftpClients[i] = sftpClient } - return &Sync{ssh_client, sftp_clients, client_count}, nil + return &Sync{sshClient, sftpClients, clientCount}, nil } // Create a new SSH client instance and confirm that we can make sessions func newSSHClient(user string, pass string, server string, port string) (*ssh.Client, error) { - ssh_config := &ssh.ClientConfig{ + sshConfig := &ssh.ClientConfig{ User: user, Auth: []ssh.AuthMethod{ ssh.Password(pass), }, } - ssh_client, err := ssh.Dial("tcp", server + ":" + port, ssh_config) + sshClient, err := ssh.Dial("tcp", server + ":" + port, sshConfig) if err != nil { - return ssh_client, err + return sshClient, err } - session, err := ssh_client.NewSession() + session, err := sshClient.NewSession() if err != nil { - return ssh_client, err + return sshClient, err } defer session.Close() - return ssh_client, err + return sshClient, err } -// Send a list of files in the format of {"source_path": "destination"} to the SSH server. -// This does not handle directories. +// SendFiles Send a list of files in the format of {"source_path": "destination"} to the SSH server. This does not handle directories. func (s Sync) SendFiles(files map[string]string) error { return SendFiles(s.sftpClients[0], files) } -// Send a list of files in the format of {"source_path": "destination"} to the SSH server. -// This does not handle directories. -func SendFiles(sftp_client *sftp.Client, files map[string]string) error { - for source_file, destination_file := range files { +// SendFiles Send a list of files in the format of {"source_path": "destination"} to the SSH server. This does not handle directories. +func SendFiles(sftpClient *sftp.Client, files map[string]string) error { + for sourceFile, destinationFile := range files { // 512KB buffer for reading/sending data data := make([]byte, 524288) // Open file that we will be sending - source_data, err := os.Open(source_file) + sourceData, err := os.Open(sourceFile) if err != nil { - log.Println("SendFiles: Failed to open source file " + source_file) + log.Println("SendFiles: Failed to open source file " + sourceFile) return err } // Get the info of the file that we will be sending - source_stat, err := source_data.Stat() + sourceStat, err := sourceData.Stat() if err != nil { - log.Println("SendFiles: Failed to stat source file " + source_file) + log.Println("SendFiles: Failed to stat source file " + sourceFile) return err } // Extract the size of the file that we will be sending - source_size := source_stat.Size() + sourceSize := sourceStat.Size() // Create the destination file for the source file we're sending - new_file, err := sftp_client.Create(destination_file) + newFile, err := sftpClient.Create(destinationFile) if err != nil { - log.Println("SendFiles: Failed to create destination file " + destination_file) + log.Println("SendFiles: Failed to create destination file " + destinationFile) return err } // Track our position in reading/writing the file - var current_position int64 = 0 - for current_position < source_size { + var currentPosition int64 + for currentPosition < sourceSize { // If the next iteration will be greater than the file size, reduce to the data size - if current_position + int64(len(data)) > source_size { - data = make([]byte, source_size - current_position) + if currentPosition + int64(len(data)) > sourceSize { + data = make([]byte, sourceSize - currentPosition) } // Read data from the source file - read, err := source_data.Read(data) + read, err := sourceData.Read(data) if err != nil { // If it's the end of the file and we didn't read anything, break if err == io.EOF { @@ -133,23 +139,23 @@ func SendFiles(sftp_client *sftp.Client, files map[string]string) error { } // Write the data from the source file to the destination file - _, err = new_file.Write(data) + _, err = newFile.Write(data) if err != nil { return err } // Update the current position in the file - current_position += int64(read) + currentPosition += int64(read) } // close the source file - err = source_data.Close() + err = sourceData.Close() if err != nil { return err } // close the destination file - err = new_file.Close() + err = newFile.Close() if err != nil { return err } @@ -158,8 +164,8 @@ func SendFiles(sftp_client *sftp.Client, files map[string]string) error { return nil } -// Get a file from the source_file path to be stored in destination_file path -func (s Sync) GetFile(source_file string, destination_file string) error { +// GetFile Get a file from the source_file path to be stored in destination_file path +func (s Sync) GetFile(sourceFile string, destinationFile string) error { // Store channels for all the concurrent download parts channels := make([]chan error, s.sftpClientCount) @@ -170,64 +176,98 @@ func (s Sync) GetFile(source_file string, destination_file string) error { // Start the concurrent downloads for i := 0; i < s.sftpClientCount; i++ { - go GetFile(s.sftpClients[i], source_file, destination_file, i + 1, s.sftpClientCount, channels[i]) + go GetFile(s.sftpClients[i], sourceFile, destinationFile, i + 1, s.sftpClientCount, channels[i]) } // Block until all downloads are completed or one errors - for _, channel := range channels { - err := <- channel - if err != nil { - return err + allDone := false + for !allDone { + allDone = true + for i, channel := range channels { + if channel == nil { + continue + } + + select { + case err := <- channel: + if err != nil { + return err + } + + channels[i] = nil + break + default: + // still running + if allDone { + allDone = false + } + break + } } + + time.Sleep(time.Second) + } + + err := destroyProgress(destinationFile) + if err != nil { + return err } return nil } -// Get a file from the source_file path to be stored in destination_file path. +// GetFile Get a file from the source_file path to be stored in destination_file path. // worker_number and work_total are not zero indexed, but 1 indexed -func GetFile(sftp_client *sftp.Client, source_file string, destination_file string, worker_number int, worker_total int, com chan <- error) error { +func GetFile(sftpClient *sftp.Client, sourceFile string, destinationFile string, workerNumber int, workerTotal int, com chan <- error) error { // Open source_data for reading - source_data, err := sftp_client.OpenFile(source_file, os.O_RDONLY) + sourceData, err := sftpClient.OpenFile(sourceFile, os.O_RDONLY) if err != nil { com <- err return err } // Get info for source_data - stat, err := source_data.Stat() + stat, err := sourceData.Stat() if err != nil { com <- err return err } // Extract the size of source_data - stat_size := stat.Size() + statSize := stat.Size() + // Calculate which byte to start reading data from - var start int64 - if worker_number == 1 { - start = 0 - } else { - start = (stat_size * int64(worker_number - 1)) / int64(worker_total) + start, err := getProgress(destinationFile, workerNumber) + if err != nil { + com <- err + return err + } + + if start == NoProgress { + if workerNumber == 1 { + start = 0 + } else { + start = (statSize * int64(workerNumber - 1)) / int64(workerTotal) + } } // Calculate which byte to stop reading data from var stop int64 - if worker_number == worker_total { - stop = stat_size + if workerNumber == workerTotal { + stop = statSize } else { - stop = (stat_size * int64(worker_number)) / int64(worker_total) + stop = (statSize * int64(workerNumber)) / int64(workerTotal) } // Create the new file for writing - new_file, err := os.OpenFile(destination_file, os.O_WRONLY | os.O_CREATE, 0777) + newFile, err := os.OpenFile(destinationFile, os.O_WRONLY | os.O_CREATE, 0777) if err != nil { com <- err return err } // Seek to the computed start point - offset, err := source_data.Seek(start, 0) + offset, err := sourceData.Seek(start, 0) if err != nil { com <- err return err @@ -235,13 +275,13 @@ func GetFile(sftp_client *sftp.Client, source_file string, destination_file stri // Seeking messed up real bad if offset != start { - err = errors.New("Returned incorrect offset for source " + source_file) + err = errors.New("Returned incorrect offset for source " + sourceFile) com <- err return err } // Seek to the computed start point - offset, err = new_file.Seek(start, 0) + offset, err = newFile.Seek(start, 0) if err != nil { com <- err return err @@ -249,31 +289,37 @@ func GetFile(sftp_client *sftp.Client, source_file string, destination_file stri // Seeking messed up real bad if offset != start { - err = errors.New("Return incorrect offset for destination " + destination_file) + err = errors.New("Return incorrect offset for destination " + destinationFile) com <- err return err } // 512KB chunks - var data_size int64 = 524288 + var dataSize int64 = 524288 // Change the size if the chunk is larger than the file - chunk_difference := stop - start - if chunk_difference < data_size { - data_size = chunk_difference + chunkDifference := stop - start + if chunkDifference < dataSize { + dataSize = chunkDifference } // Initialize the buffer for reading/writing - data := make([]byte, data_size) + data := make([]byte, dataSize) + var currentSize int64 + for currentSize = start; currentSize < stop; currentSize += dataSize { + err = updateProgress(destinationFile, currentSize, workerNumber) + if err != nil { + com <- err + return err + } - for current_size := start; current_size < stop; current_size += data_size { // Adjust the size of the buffer if the next iteration will be greater than what has yet to be read - if current_size + data_size > stop { - data_size = stop - current_size - data = make([]byte, data_size) + if currentSize + dataSize > stop { + dataSize = stop - currentSize + data = make([]byte, dataSize) } // Read the chunk - read, err := source_data.Read(data) + read, err := sourceData.Read(data) if err != nil { // Exit the loop if we're at the end of the file and no data was read if err == io.EOF { @@ -287,21 +333,27 @@ func GetFile(sftp_client *sftp.Client, source_file string, destination_file stri } // Write the chunk - _, err = new_file.Write(data) + _, err = newFile.Write(data) if err != nil { com <- err return err } } - // Close out the files - err = source_data.Close() + err = updateProgress(destinationFile, currentSize, workerNumber) if err != nil { com <- err return err } - err = new_file.Close() + // Close out the files + err = sourceData.Close() + if err != nil { + com <- err + return err + } + + err = newFile.Close() if err != nil { com <- err return err @@ -311,23 +363,23 @@ func GetFile(sftp_client *sftp.Client, source_file string, destination_file stri return nil } -// Get a given directory or file defined by source_path and save it to destination_path -func (s Sync) GetPath(source_path string, destination_path string) error { +// GetPath Get a given directory or file defined by source_path and save it to destination_path +func (s Sync) GetPath(sourcePath string, destinationPath string) error { // Get all the dirs and files underneath source_path - dirs, files, err := s.getChildren(source_path) + dirs, files, err := s.getChildren(sourcePath) // Remove the trailing slash if it exists - if source_path[len(source_path) - 1] == '/' { - source_path = source_path[:len(source_path) - 1] + if sourcePath[len(sourcePath) - 1] == '/' { + sourcePath = sourcePath[:len(sourcePath) - 1] } // Get the parent path of source_path - source_base := filepath.Dir(source_path) - source_base_len := len(source_base) + sourceBase := filepath.Dir(sourcePath) + sourceBaseLen := len(sourceBase) // Make all the directories in destination_path for _, dir := range dirs { - dir = filepath.Join(destination_path, filepath.FromSlash(dir[source_base_len:])) + dir = filepath.Join(destinationPath, filepath.FromSlash(dir[sourceBaseLen:])) err = os.MkdirAll(dir, 0777) if err != nil { return err @@ -336,8 +388,8 @@ func (s Sync) GetPath(source_path string, destination_path string) error { // Get all the files and place them in destination_path for _, file := range files { - new_file := filepath.Join(destination_path, filepath.FromSlash(file[source_base_len:])) - err = s.GetFile(file, new_file) + newFile := filepath.Join(destinationPath, filepath.FromSlash(file[sourceBaseLen:])) + err = s.GetFile(file, newFile) if err != nil { return err } @@ -352,9 +404,9 @@ func (s Sync) getChildren(root string) ([]string, []string, error) { walker := s.sftpClients[0].Walk(root) // Keep track of the directories - dirs := make([]string, 0) + var dirs []string // Keep track of the files - files := make([]string, 0) + var files []string // Walk through the files and directories for walker.Step() { @@ -379,7 +431,7 @@ func (s Sync) getChildren(root string) ([]string, []string, error) { return dirs, files, nil } -// Determine if a directory, file, or link exists +// Exists Determine if a directory, file, or link exists func (s Sync) Exists(path string) (bool, error) { _, err := s.sftpClients[0].Lstat(path) @@ -393,3 +445,92 @@ func (s Sync) Exists(path string) (bool, error) { return true, nil } + +func getProgress(filePath string, workerNumber int) (int64, error) { + file, err := os.Open(getProgressPath(filePath)) + if err != nil { + if os.IsNotExist(err) { + return NoProgress, nil + } + + return 0, err + } + + fileStat, err := file.Stat() + if err != nil { + return 0, err + } + + if fileStat.Size() == 0 { + return NoProgress, nil + } + + var progress int64 + progressSize := int64(binary.Size(progress)) + offset := progressSize * int64(workerNumber - 1) + + realOffset, err := file.Seek(offset, os.SEEK_SET) + if err != nil { + return 0, err + } + + if realOffset != offset { + return 0, errors.New("getProgress: Tried to seek to " + string(offset) + " but got " + string(realOffset) + " instead") + } + + progressData := make([]byte, progressSize) + + read, err := file.Read(progressData) + if err != nil { + if err == io.EOF { + return NoProgress, nil + } + + return 0, err + } + + if int64(read) != progressSize { + return NoProgress, nil + } + + err = binary.Read(bytes.NewReader(progressData), binary.BigEndian, &progress) + if err != nil { + return 0, err + } + + return progress, nil +} + +func updateProgress(filePath string, written int64, workerNumber int) error { + file, err := os.OpenFile(getProgressPath(filePath), os.O_WRONLY | os.O_CREATE, 0777) + if err != nil { + return err + } + + writtenSize := int64(binary.Size(written)) + offset := writtenSize * int64(workerNumber - 1) + + realOffset, err := file.Seek(offset, os.SEEK_SET) + if err != nil { + return err + } + + if realOffset != offset { + return errors.New("updateProgress: Tried to seek to " + string(offset) + " but got " + string(realOffset) + " instead") + } + + return binary.Write(file, binary.BigEndian, written) +} + +func destroyProgress(filePath string) error { + err := os.Remove(getProgressPath(filePath)) + if err != nil && !os.IsNotExist(err) { + return err + } + + return nil +} + +func getProgressPath(filePath string) string { + return filepath.Join(filepath.Dir(filePath), "." + filepath.Base(filePath) + ".progress") +}