diff --git a/hoarder.conf.ex b/hoarder.conf.ex
index 31df3af..1098f1a 100644
--- a/hoarder.conf.ex
+++ b/hoarder.conf.ex
@@ -1,6 +1,8 @@
 // Configuration for Hoarder
 // All fields and values are necessary
+// Should Hoarder automatically restart when one of its processes errors? (default true if not set)
+restart_on_error: "true"
 
 // Username for XMLRPC
 xml_user: "testuser"
 // Password for XMLRPC
diff --git a/hoarder.go b/hoarder.go
index 64b832a..ef1934e 100644
--- a/hoarder.go
+++ b/hoarder.go
@@ -43,14 +43,14 @@ func loadConfig(configPath string) (map[string]string, error) {
 			continue
 		}
 
-		config[line[:sepPosition]] = line[sepPosition + 3:len(line) - 1]
+		config[line[:sepPosition]] = line[sepPosition+3 : len(line)-1]
 	}
 
 	return config, nil
 }
 
 // Checker routine to see if torrents are completed
-func checker(config map[string]string, checkerChan <- chan map[string]string, com chan <- error) error {
+func checker(config map[string]string, checkerChan <-chan map[string]string, com chan<- error) error {
 	for {
 		torrentInfo := <-checkerChan
 
@@ -66,6 +66,7 @@ func checker(config map[string]string, checkerChan <- chan map[string]string, co
 		}
 
 		syncer, err := NewSync(config["threads"], config["ssh_user"], config["ssh_pass"], config["ssh_server"], config["ssh_port"])
+		defer syncer.Close()
 		if err != nil {
 			log.Println("Failed to create a new sync: " + err.Error())
 			com <- err
@@ -142,6 +143,8 @@ func checker(config map[string]string, checkerChan <- chan map[string]string, co
 			} else {
 				log.Println(name + " is not completed, waiting for it to finish")
 			}
+
+			syncer.Close()
 		}
 		com <- nil
 	}
@@ -149,7 +152,7 @@ func checker(config map[string]string, checkerChan <- chan map[string]string, co
 }
 
 // Scanner routine to see if there are new torrent_files
-func scanner(config map[string]string, checkerChan chan <- map[string]string, com chan <- error) error {
+func scanner(config map[string]string, checkerChan chan<- map[string]string, com chan<- error) error {
 	watchDirs := map[string]string{config["local_torrent_dir"]: config["local_download_dir"]}
 
 	dirContents, err := ioutil.ReadDir(config["local_torrent_dir"])
@@ -186,6 +189,7 @@ func scanner(config map[string]string, checkerChan chan <- map[string]string, co
 		syncer, err := NewSync("1", config["ssh_user"], config["ssh_pass"], config["ssh_server"], config["ssh_port"])
 		if err != nil {
 			log.Println("Failed to create a new sync: " + err.Error())
+			syncer.Close()
 			continue
 		}
 
@@ -193,6 +197,7 @@ func scanner(config map[string]string, checkerChan chan <- map[string]string, co
 		exists, err := syncer.Exists(destinationTorrent)
 		if err != nil {
 			log.Println("Failed to see if " + torrentPath + " already exists on the server: " + err.Error())
+			syncer.Close()
 			continue
 		}
 
@@ -207,12 +212,15 @@ func scanner(config map[string]string, checkerChan chan <- map[string]string, co
 				log.Println("Failed to upload " + torrentPath + " to " + destinationTorrent + ": " + err.Error())
 			}
 
+			syncer.Close()
 			continue
 		}
+
+		syncer.Close()
 	}
 
 	downloadInfo := map[string]string{
-		"torrent_path": torrentPath,
+		"torrent_path":       torrentPath,
 		"local_download_dir": downloadDir,
 	}
 
@@ -239,10 +247,14 @@ func scanner(config map[string]string, checkerChan chan <- map[string]string, co
 	return nil
 }
 
+func die(exitCode int) {
+	log.Println("Quiting")
+	os.Exit(exitCode)
+}
+
 func main() {
 	sigint.ListenForSIGINT(func() {
-		log.Println("Quiting")
-		os.Exit(1)
+		die(1)
 	})
 
 	var configPath string
@@ -252,14 +264,14 @@ func main() {
 	if configPath == "" {
 		log.Println("Missing argument for configuration file path")
 		flag.PrintDefaults()
-		os.Exit(1)
+		die(1)
 	}
 
 	log.Println("Reading configuration file")
 	config, err := loadConfig(configPath)
 	if err != nil {
 		log.Println(err)
-		os.Exit(1)
+		die(1)
 	}
 
 	log.Println("Successfully read configuration file")
@@ -268,7 +280,7 @@ func main() {
 
 	if err != nil {
 		log.Println(err)
-		os.Exit(1)
+		die(1)
 	}
 
 	log.Println("Starting the scanner routine")
@@ -279,17 +291,36 @@ func main() {
 	checkerCom := make(chan error)
 	go checker(config, checkerChan, checkerCom)
 
+	restartOnError := true
+	if config["restart_on_error"] != "" {
+		restartOnError = config["restart_on_error"] == "true"
+	}
+
 	for {
 		select {
 		case err := <-scannerCom:
 			if err != nil {
 				log.Println("Scanner failed: " + err.Error())
-				os.Exit(1)
+
+				if restartOnError {
+					log.Println("Restarting scanner")
+					go scanner(config, checkerChan, scannerCom)
+				} else {
+					log.Println("Quiting due to scanner error")
+					die(1)
+				}
 			}
 		case err := <-checkerCom:
 			if err != nil {
 				log.Println("Checker failed: " + err.Error())
-				os.Exit(1)
+
+				if restartOnError {
+					log.Println("Restarting checker")
+					go checker(config, checkerChan, checkerCom)
+				} else {
+					log.Println("Quiting due to checker error")
+					die(1)
+				}
 			}
 		default:
 			break
diff --git a/sync.go b/sync.go
index 869bec3..6be8208 100644
--- a/sync.go
+++ b/sync.go
@@ -22,8 +22,8 @@ const (
 
 // Sync Keeps track of the connection information.
 type Sync struct {
-	sshClient *ssh.Client
-	sftpClients []*sftp.Client
+	sshClient       *ssh.Client
+	sftpClients     []*sftp.Client
 	sftpClientCount int
 }
 
@@ -58,6 +58,24 @@ func NewSync(threads string, user string, pass string, server string, port strin
 	return &Sync{sshClient, sftpClients, clientCount}, nil
 }
 
+// Close Closes all of the ssh and sftp connections to the SSH server.
+func (s Sync) Close() error {
+	var returnError error
+	for i := 0; i < s.sftpClientCount; i++ {
+		err := s.sftpClients[i].Close()
+		if err != nil {
+			returnError = err
+		}
+	}
+
+	err := s.sshClient.Close()
+	if err != nil {
+		return err
+	}
+
+	return returnError
+}
+
 // Create a new SSH client instance and confirm that we can make sessions
 func newSSHClient(user string, pass string, server string, port string) (*ssh.Client, error) {
 	sshConfig := &ssh.ClientConfig{
@@ -67,7 +85,7 @@ func newSSHClient(user string, pass string, server string, port string) (*ssh.Cl
 		},
 	}
 
-	sshClient, err := ssh.Dial("tcp", server + ":" + port, sshConfig)
+	sshClient, err := ssh.Dial("tcp", server+":"+port, sshConfig)
 
 	if err != nil {
 		return sshClient, err
@@ -121,8 +139,8 @@ func SendFiles(sftpClient *sftp.Client, files map[string]string) error {
 		var currentPosition int64
 		for currentPosition < sourceSize {
 			// If the next iteration will be greater than the file size, reduce to the data size
-			if currentPosition + int64(len(data)) > sourceSize {
-				data = make([]byte, sourceSize - currentPosition)
+			if currentPosition+int64(len(data)) > sourceSize {
+				data = make([]byte, sourceSize-currentPosition)
 			}
 
 			// Read data from the source file
@@ -176,7 +194,7 @@ func (s Sync) GetFile(sourceFile string, destinationFile string) error {
 
 	// Start the concurrent downloads
 	for i := 0; i < s.sftpClientCount; i++ {
-		go GetFile(s.sftpClients[i], sourceFile, destinationFile, i + 1, s.sftpClientCount, channels[i])
+		go GetFile(s.sftpClients[i], sourceFile, destinationFile, i+1, s.sftpClientCount, channels[i])
 	}
 
 	// Block until all downloads are completed or one errors
@@ -189,7 +207,7 @@ func (s Sync) GetFile(sourceFile string, destinationFile string) error {
 		}
 
 		select {
-		case err := <- channel:
+		case err := <-channel:
 			if err != nil {
 				return err
 			}
@@ -218,7 +236,7 @@ func (s Sync) GetFile(sourceFile string, destinationFile string) error {
 
 // GetFile Get a file from the source_file path to be stored in destination_file path.
 // worker_number and work_total are not zero indexed, but 1 indexed
-func GetFile(sftpClient *sftp.Client, sourceFile string, destinationFile string, workerNumber int, workerTotal int, com chan <- error) error {
+func GetFile(sftpClient *sftp.Client, sourceFile string, destinationFile string, workerNumber int, workerTotal int, com chan<- error) error {
 	// Open source_data for reading
 	sourceData, err := sftpClient.OpenFile(sourceFile, os.O_RDONLY)
 	if err != nil {
@@ -247,7 +265,7 @@ func GetFile(sftpClient *sftp.Client, sourceFile string, destinationFile string,
 		if workerNumber == 1 {
 			start = 0
 		} else {
-			start = (statSize * int64(workerNumber - 1)) / int64(workerTotal)
+			start = (statSize * int64(workerNumber-1)) / int64(workerTotal)
 		}
 	}
 
@@ -260,7 +278,7 @@ func GetFile(sftpClient *sftp.Client, sourceFile string, destinationFile string,
 	}
 	// Create the new file for writing
-	newFile, err := os.OpenFile(destinationFile, os.O_WRONLY | os.O_CREATE, 0777)
+	newFile, err := os.OpenFile(destinationFile, os.O_WRONLY|os.O_CREATE, 0777)
 	if err != nil {
 		com <- err
 		return err
 	}
@@ -313,7 +331,7 @@ func GetFile(sftpClient *sftp.Client, sourceFile string, destinationFile string,
 		}
 
 		// Adjust the size of the buffer if the next iteration will be greater than what has yet to be read
-		if currentSize + dataSize > stop {
+		if currentSize+dataSize > stop {
 			dataSize = stop - currentSize
 			data = make([]byte, dataSize)
 		}
@@ -369,8 +387,8 @@ func (s Sync) GetPath(sourcePath string, destinationPath string) error {
 	dirs, files, err := s.getChildren(sourcePath)
 
 	// Remove the trailing slash if it exists
-	if sourcePath[len(sourcePath) - 1] == '/' {
-		sourcePath = sourcePath[:len(sourcePath) - 1]
+	if sourcePath[len(sourcePath)-1] == '/' {
+		sourcePath = sourcePath[:len(sourcePath)-1]
 	}
 
 	// Get the parent path of source_path
@@ -467,7 +485,7 @@ func getProgress(filePath string, workerNumber int) (int64, error) {
 	var progress int64
 	progressSize := int64(binary.Size(progress))
 
-	offset := progressSize * int64(workerNumber - 1)
+	offset := progressSize * int64(workerNumber-1)
 
 	realOffset, err := file.Seek(offset, os.SEEK_SET)
 	if err != nil {
@@ -502,13 +520,13 @@ func getProgress(filePath string, workerNumber int) (int64, error) {
 }
 
 func updateProgress(filePath string, written int64, workerNumber int) error {
-	file, err := os.OpenFile(getProgressPath(filePath), os.O_WRONLY | os.O_CREATE, 0777)
+	file, err := os.OpenFile(getProgressPath(filePath), os.O_WRONLY|os.O_CREATE, 0777)
 	if err != nil {
 		return err
 	}
 
 	writtenSize := int64(binary.Size(written))
-	offset := writtenSize * int64(workerNumber - 1)
+	offset := writtenSize * int64(workerNumber-1)
 
 	realOffset, err := file.Seek(offset, os.SEEK_SET)
 	if err != nil {
@@ -532,5 +550,5 @@ func destroyProgress(filePath string) error {
 }
 
 func getProgressPath(filePath string) string {
-	return filepath.Join(filepath.Dir(filePath), "." + filepath.Base(filePath) + ".progress")
+	return filepath.Join(filepath.Dir(filePath), "."+filepath.Base(filePath)+".progress")
 }
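The restart behaviour this patch adds to main() can be exercised on its own. The sketch below is illustrative only: worker, the hard-coded config map, and the restart cap are stand-ins rather than code from hoarder.go, but the parsing rule is the same as in the patch — restart_on_error defaults to true, and once the key is set, only the exact string "true" keeps restarts enabled.

// restart_sketch.go -- standalone illustration; not part of the patch.
package main

import (
	"errors"
	"log"
	"os"
	"time"
)

// worker stands in for the scanner/checker goroutines and reports failures
// over com, the way they report over scannerCom/checkerCom.
func worker(com chan<- error) {
	time.Sleep(100 * time.Millisecond)
	com <- errors.New("simulated failure")
}

func main() {
	// Hypothetical config values; hoarder.go reads these from the config file.
	config := map[string]string{"restart_on_error": "true"}

	// Same parsing as the patch: default to true, override only when the key
	// is present, and treat anything other than "true" as false.
	restartOnError := true
	if config["restart_on_error"] != "" {
		restartOnError = config["restart_on_error"] == "true"
	}

	com := make(chan error)
	go worker(com)

	restarts := 0
	for err := range com {
		if err == nil {
			continue
		}

		log.Println("worker failed: " + err.Error())

		if restartOnError && restarts < 3 { // cap added so the sketch terminates
			restarts++
			log.Println("restarting worker")
			go worker(com)
			continue
		}

		log.Println("quitting due to worker error")
		os.Exit(1)
	}
}

As written it restarts three times and then quits; the patched main() has no such cap and keeps restarting the failed routine for as long as restart_on_error is "true".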
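The new Sync.Close tears the connections down in a fixed order: every sftp client is closed first and the last error seen is remembered, then the underlying ssh client is closed, with an ssh close error taking precedence. Below is a generic sketch of that pattern over io.Closer; closeAll and fakeConn are hypothetical names used for illustration, not part of sync.go.

// closeall_sketch.go -- generic version of the Sync.Close teardown pattern.
package main

import (
	"fmt"
	"io"
)

// closeAll closes each child connection first and the parent last, mirroring
// how Sync.Close closes the sftp clients before the ssh client. A parent close
// error wins; otherwise the last child error (if any) is returned.
func closeAll(parent io.Closer, children []io.Closer) error {
	var returnError error
	for _, c := range children {
		if err := c.Close(); err != nil {
			returnError = err
		}
	}

	if err := parent.Close(); err != nil {
		return err
	}

	return returnError
}

// fakeConn is a stand-in for an ssh/sftp client so the sketch runs on its own.
type fakeConn struct{ name string }

func (f fakeConn) Close() error {
	fmt.Println("closed", f.name)
	return nil
}

func main() {
	children := []io.Closer{fakeConn{"sftp-1"}, fakeConn{"sftp-2"}}
	if err := closeAll(fakeConn{"ssh"}, children); err != nil {
		fmt.Println("close error:", err)
	}
}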
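The offset arithmetic that this diff reformats in GetFile and the progress helpers can be sanity-checked by hand: worker n of N (both 1-indexed) starts at statSize*(n-1)/N, and its slot in the .progress file sits at binary.Size(int64)*(n-1). The sketch below prints those values for made-up inputs; the stop offset is an assumption here — it is taken to be where the next worker starts, since the hunks above do not show how sync.go actually computes it.

// ranges_sketch.go -- prints assumed per-worker byte ranges; the file size and
// worker count are arbitrary example values.
package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	const statSize = int64(100 << 20) // pretend the remote file is 100 MiB
	const workerTotal = 4

	// Progress slot width, as in getProgress/updateProgress:
	// offset = binary.Size(int64) * (workerNumber - 1).
	progressSize := int64(binary.Size(int64(0)))

	for workerNumber := 1; workerNumber <= workerTotal; workerNumber++ {
		// Start offset, same formula as GetFile: worker 1 starts at 0,
		// worker n starts at statSize*(n-1)/workerTotal.
		start := int64(0)
		if workerNumber > 1 {
			start = (statSize * int64(workerNumber-1)) / int64(workerTotal)
		}

		// Assumed stop offset (not shown in the hunks): where the next worker starts.
		stop := (statSize * int64(workerNumber)) / int64(workerTotal)

		fmt.Printf("worker %d: bytes [%d, %d), progress offset %d\n",
			workerNumber, start, stop, progressSize*int64(workerNumber-1))
	}
}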