#!/usr/bin/env python3
# remote-run - Runs a command on another machine, for testing -----*- python -*-
#
# This source file is part of the Swift.org open source project
#
# Copyright (c) 2018 Apple Inc. and the Swift project authors
# Licensed under Apache License v2.0 with Runtime Library Exception
#
# See https://swift.org/LICENSE.txt for license information
# See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
#
# ----------------------------------------------------------------------------

import argparse
import os
import posixpath
import subprocess
import sys
import shutil

def quote(arg):
    """Return *arg* quoted as a single POSIX shell word.

    The result is always wrapped in single quotes, so ordinary arguments
    render exactly as before (``abc`` -> ``'abc'``).  Unlike ``repr()``,
    which emits Python literal syntax, this is safe for any content:
    backslashes and newlines pass through untouched, and an embedded single
    quote is emitted as ``'"'"'`` (close the quoted run, add a double-quoted
    single quote, reopen), so the shell never sees an unquoted or
    double-quoted span in which ``$`` or backticks could be expanded.
    """
    return "'" + str(arg).replace("'", "'\"'\"'") + "'"

class CommandRunner(object):
    """Base class that runs commands and rsync transfers, retrying on
    known transient failures.

    The methods here call ``self.remote_invocation``,
    ``self.rsync_to_invocation`` and ``self.rsync_from_invocation``, which
    are not defined on this class; concrete runners must provide them.

    Attributes:
        verbose: when true, echo each command to stderr before running it.
        dry_run: when true, nothing is actually executed.
        ignore_rsync_failure: when true, a non-zero rsync exit status is
            not fatal.
    """

    # Total number of times a command is attempted when its stderr matches
    # one of the transient-failure markers below.
    MAX_ATTEMPTS = 3

    # Recover from random and transient errors that occur when SSHing to
    # devices, in particular devicecompute.
    TRANSIENT_ERROR_MARKERS = (
        "banner line contains invalid characters",
        "Connection to localhost closed by remote host",
        "kex_exchange_identification: Connection closed by remote host",
        "rsync error: unexplained error",
    )

    def __init__(self):
        self.verbose = False
        self.dry_run = False
        self.ignore_rsync_failure = False

    @staticmethod
    def _dirnames(files):
        """Return the sorted, de-duplicated POSIX dirnames of *files*."""
        return sorted(set(posixpath.dirname(f) for f in files))

    def popen(self, command, **kwargs):
        """Start *command* (an argv list) and return the ``Popen`` object.

        The command is echoed to stderr when ``verbose`` is set.  In
        dry-run mode nothing is started and ``None`` is returned.
        """
        if self.verbose:
            print(' '.join(command), file=sys.stderr)
        if self.dry_run:
            return None
        return subprocess.Popen(command, **kwargs)

    def mkdirs_remote(self, directories):
        """Create every path in *directories* via ``run_remote``.

        Does nothing when *directories* is empty.
        """
        if directories:
            # list() lets callers pass any iterable, not only a list.
            mkdir_command = ['/bin/mkdir', '-p'] + list(directories)
            self.run_remote(mkdir_command)

    def send(self, input_prefix, remote_prefix, local_to_remote_files):
        """Copy *local_to_remote_files* (paths relative to *input_prefix*)
        into *remote_prefix*, creating that directory first."""
        # Prepare the remote directory structure.
        self.mkdirs_remote([remote_prefix])

        self.run_rsync_to(input_prefix, local_to_remote_files, remote_prefix)

    def fetch(self, output_prefix, remote_prefix, remote_to_local_files):
        """Copy *remote_to_local_files* (paths relative to *remote_prefix*)
        into the local *output_prefix*, creating that directory first."""
        # Prepare the local directory structure.
        mkdir_command = ['/bin/mkdir', '-p', output_prefix]
        if self.verbose:
            print(' '.join(mkdir_command), file=sys.stderr)
        if not self.dry_run:
            subprocess.check_call(mkdir_command)

        self.run_rsync_from(remote_prefix, remote_to_local_files, output_prefix)

    @staticmethod
    def should_remote_command_retry(stderr):
        """Return True when *stderr* contains a known transient error."""
        # Anything not listed is an unknown error and shouldn't be retried.
        return any(marker in stderr
                   for marker in CommandRunner.TRANSIENT_ERROR_MARKERS)

    def _run_with_retries(self, invocation, stdin_text=None):
        """Run *invocation*, retrying while its stderr looks transient.

        *stdin_text*, when given, is echoed to stderr in verbose mode and
        fed to the process on every attempt.  The child's stdout is printed
        after each attempt; its stderr is printed only for the attempt
        whose outcome is final.

        Returns the exit status of the first attempt that did not hit a
        transient error, or ``None`` in dry-run mode.  If every attempt hit
        a transient error the script exits with a non-zero status.
        """
        stdin_bytes = None
        if stdin_text is not None:
            stdin_bytes = stdin_text.encode('utf-8')

        returncode = None
        stderr = ''

        for _ in range(self.MAX_ATTEMPTS):
            proc = self.popen(
                invocation,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE
            )

            if self.dry_run:
                return None

            if self.verbose and stdin_text is not None:
                print(stdin_text, file=sys.stderr)

            stdout, stderr = proc.communicate(stdin_bytes)
            # The child may emit arbitrary bytes; never crash while relaying.
            stdout = stdout.decode('utf-8', errors='replace')
            stderr = stderr.decode('utf-8', errors='replace')
            returncode = proc.returncode

            # Print stdout to screen
            print(stdout, end='')

            if not self.should_remote_command_retry(stderr):
                print(stderr, end='', file=sys.stderr)
                return returncode
            # Otherwise this is a transient and random error: simply retry.

        # Every attempt failed transiently.  Show the last diagnostics
        # (suppressed above for retried attempts) and make sure the failure
        # is never reported as exit status 0.
        print(stderr, end='', file=sys.stderr)
        sys.exit(returncode or 1)

    def run_remote(self, command, remote_env=None):
        """Run *command* (an argv list) through ``remote_invocation``.

        *remote_env* is an optional mapping of environment variables, passed
        to the command via ``/usr/bin/env``.  Exits the script with the
        command's status if it is non-zero.
        """
        env_items = sorted((remote_env or {}).items())
        env_strings = ['{0}={1}'.format(k, v) for k, v in env_items]
        remote_invocation = self.remote_invocation(
            ['/usr/bin/env'] + env_strings + list(command))

        returncode = self._run_with_retries(remote_invocation)
        if returncode:
            # FIXME: We may still want to fetch the output files to see what
            # went wrong.
            sys.exit(returncode)

    def run_rsync(self, invocation, sources):
        """Run the rsync *invocation*, feeding *sources* on stdin.

        *sources* is an iterable of paths, sent newline-separated.  A
        non-zero rsync status exits the script unless
        ``ignore_rsync_failure`` is set.
        """
        # Join exactly once, before any attempt: joining inside the retry
        # loop would re-join the already-joined string character by
        # character and corrupt the file list on every retry.
        file_list = '\n'.join(sources)

        returncode = self._run_with_retries(invocation, file_list)
        if returncode and not self.ignore_rsync_failure:
            sys.exit(returncode)

    def run_rsync_to(self, prefix, sources, dest):
        """Transfer *sources* from *prefix* to *dest* using
        ``rsync_to_invocation``."""
        self.run_rsync(self.rsync_to_invocation(prefix, dest), sources)

    def run_rsync_from(self, prefix, sources, dest):
        """Transfer *sources* from *prefix* to *dest* using
        ``rsync_from_invocation``."""
        self.run_rsync(self.rsync_from_invocation(prefix, dest), sources)

class RemoteCommandRunner(CommandRunner):
    """Runner that executes commands over ssh and transfers files with
    rsync-over-ssh."""

    def __init__(self, host, identity_path, ssh_options, config_file):
        """Set up a runner for *host*.

        Args:
            host: destination; everything after the last ``:`` (if any) is
                taken as the port.
            identity_path: value for ssh ``-i``, or a false value to omit.
            ssh_options: iterable of values, each passed with ssh ``-o``.
            config_file: value for ssh ``-F``, or a false value to omit.
        """
        # Initialise verbose / dry_run / ignore_rsync_failure so the runner
        # is usable even when the caller does not set them explicitly.
        super().__init__()
        if ':' in host:
            (self.remote_host, self.port) = host.rsplit(':', 1)
        else:
            self.remote_host = host
            self.port = None
        self.identity_path = identity_path
        self.ssh_options = ssh_options
        self.config_file = config_file

    def common_options(self, port_flag):
        """Return the ssh option list: port (using *port_flag*), identity,
        config file, then each extra option preceded by ``-o``."""
        port_option = [port_flag, self.port] if self.port else []
        config_option = ['-F', self.config_file] if self.config_file else []
        identity_option = (
            ['-i', self.identity_path] if self.identity_path else [])
        # Interleave '-o' with each custom option.
        # From https://stackoverflow.com/a/8168526,
        # with explanatory help from
        # https://spapas.github.io/2016/04/27/python-nested-list-comprehensions/
        extra_options = [arg for option in self.ssh_options
                             for arg in ["-o", option]]
        return port_option + identity_option + config_option + extra_options

    def remote_invocation(self, command):
        """Return the ssh argv that runs *command* on the remote host.

        Each element of *command* is passed through ``quote``.
        """
        return (['/usr/bin/ssh', '-n'] +
                self.common_options(port_flag='-p') +
                [self.remote_host, '--'] +
                [quote(arg) for arg in command])

    def rsync_invocation(self, source, dest):
        """Return the rsync argv copying from *source* to *dest*, with the
        file list read from stdin and ssh (plus the common options) as the
        transport given to ``-e``."""
        return ['/usr/bin/rsync', '-arRz', '--files-from=-',
                '-e', ' '.join([quote(x) for x in
                                ['/usr/bin/ssh'] +
                                self.common_options(port_flag='-p')]),
                source,
                dest]

    def rsync_from_invocation(self, source, dest):
        """rsync argv for remote *source* -> local *dest*."""
        return self.rsync_invocation(self.remote_host + ':' + source, dest)

    def rsync_to_invocation(self, source, dest):
        """rsync argv for local *source* -> remote *dest*."""
        return self.rsync_invocation(source, self.remote_host + ':' + dest)

class LocalCommandRunner(CommandRunner):
    """Runner whose invocations carry no ssh wrapper or host prefix."""

    def remote_invocation(self, command):
        """Return *command* unchanged."""
        return command

    def rsync_invocation(self, source, dest):
        """Return the rsync argv copying *source* to *dest*, with the file
        list read from stdin."""
        rsync_flags = ['-arz', '--files-from=-']
        endpoints = [source, dest]
        return ['/usr/bin/rsync'] + rsync_flags + endpoints

    # Both directions use plain paths, so they share rsync_invocation.

    def rsync_from_invocation(self, source, dest):
        """rsync argv for the "fetch" direction."""
        return self.rsync_invocation(source, dest)

    def rsync_to_invocation(self, source, dest):
        """rsync argv for the "send" direction."""
        return self.rsync_invocation(source, dest)

def strip_sep(name):
    """Return *name* with every leading POSIX path separator removed."""
    # posixpath.sep is the single character '/', so lstrip() removes exactly
    # the leading run of separators.
    return name.lstrip(posixpath.sep)

class RemotePathSet(object):
    """Six independent, initially empty sets of path names, grouped by
    input/output role and by "nodir" / "existing" variants."""

    def __init__(self):
        # Input paths, and their "nodir" counterpart.
        self.inputs, self.nodir_inputs = set(), set()
        # Output paths, and their "nodir" counterpart.
        self.outputs, self.nodir_outputs = set(), set()
        # The "existing" subsets of the two output groups.
        self.existing_outputs, self.existing_nodir_outputs = set(), set()

class PrefixProcessor(object):
    """Rewrites strings that start with the local input or output prefix
    into paths under the remote directory, recording each rewritten name in
    the shared path set."""

    def __init__(self,
                 input_prefix, output_prefix,
                 remote_dir, remote_input_prefix,
                 remote_output_prefix,
                 path_set):
        assert not remote_input_prefix.startswith('..')
        assert not remote_output_prefix.startswith('..')

        self.remote_dir = remote_dir
        self.remote_input_prefix = remote_input_prefix
        self.remote_output_prefix = remote_output_prefix
        self.path_set = path_set

        self.input_prefix, self.input_prefix_split = (
            self._prepare_prefix(input_prefix))
        self.output_prefix, self.output_prefix_split = (
            self._prepare_prefix(output_prefix))

    @staticmethod
    def _prepare_prefix(prefix):
        """Return ``(trimmed_prefix, (length, parent_dir, last_component))``.

        Trailing separators are dropped.  A prefix that is false, or becomes
        empty once trimmed, yields the split ``(0, '', '')``.
        """
        if not prefix:
            return prefix, (0, '', '')
        trimmed = prefix.rstrip(posixpath.sep)
        if not trimmed:
            return trimmed, (0, '', '')
        parent, leaf = posixpath.split(trimmed)
        return trimmed, (len(trimmed), parent, leaf)

    def process_one(self, orig_name):
        """Return the remote form of *orig_name*.

        A name starting with the input prefix is checked first, then one
        starting with the output prefix; anything else is returned
        unchanged.  When the text after the prefix begins with a separator,
        the name is recorded relative to the prefix directory; otherwise the
        prefix's last component is kept as the start of the name and it is
        recorded in the matching ``nodir_*`` set.  Output names that exist
        locally are also recorded in the matching ``existing_*`` set.
        """
        in_len, _, in_leaf = self.input_prefix_split
        out_len, _, out_leaf = self.output_prefix_split
        paths = self.path_set

        if in_len and orig_name.startswith(self.input_prefix):
            tail = orig_name[in_len:]
            if tail.startswith(posixpath.sep):
                relative = tail.lstrip(posixpath.sep)
                paths.inputs.add(relative)
            else:
                relative = in_leaf + tail
                paths.nodir_inputs.add(relative)
            return posixpath.join(self.remote_dir,
                                  self.remote_input_prefix,
                                  relative)

        if out_len and orig_name.startswith(self.output_prefix):
            tail = orig_name[out_len:]
            if tail.startswith(posixpath.sep):
                relative = tail.lstrip(posixpath.sep)
                recorded = paths.outputs
                existing = paths.existing_outputs
            else:
                relative = out_leaf + tail
                recorded = paths.nodir_outputs
                existing = paths.existing_nodir_outputs
            recorded.add(relative)
            if os.path.exists(orig_name):
                existing.add(relative)
            return posixpath.join(self.remote_dir,
                                  self.remote_output_prefix,
                                  relative)

        return orig_name

class ArgumentProcessor(PrefixProcessor):
    """Applies prefix rewriting to a list of command-line arguments."""

    def __init__(self,
                 input_prefix, output_prefix,
                 remote_dir, remote_input_prefix, remote_output_prefix,
                 path_set, arguments):
        super().__init__(
            input_prefix,
            output_prefix,
            remote_dir,
            remote_input_prefix,
            remote_output_prefix,
            path_set,
        )
        # Kept untouched; process_args() writes its result to self.args.
        self.original_args = arguments

    def process_args(self):
        """Store in ``self.args`` the result of ``process_one`` for every
        original argument, in order."""
        self.args = [self.process_one(argument)
                     for argument in self.original_args]

class EnvVarProcessor(PrefixProcessor):
    """Applies prefix rewriting to the values of an environment mapping."""

    def __init__(self,
                 input_prefix, output_prefix,
                 remote_dir, remote_input_prefix, remote_output_prefix,
                 path_set, environment):
        super().__init__(
            input_prefix,
            output_prefix,
            remote_dir,
            remote_input_prefix,
            remote_output_prefix,
            path_set,
        )
        # Kept untouched; process_env() writes its result to self.env.
        self.original_env = environment

    def process_env(self):
        """Store in ``self.env`` a dict with the same keys as the original
        environment and every value passed through ``process_one``."""
        self.env = {name: self.process_one(setting)
                    for name, setting in self.original_env.items()}

def collect_remote_env(local_env=os.environ, prefix='REMOTE_RUN_CHILD_'):
    """Return a dict of the entries of *local_env* whose key starts with
    *prefix*, with that prefix removed from each key."""
    skip = len(prefix)
    return {
        name[skip:]: setting
        for name, setting in local_env.items()
        if name.startswith(prefix)
    }

def main():
    """Command-line entry point.

    Parses the options, rewrites command arguments and
    ``REMOTE_RUN_CHILD_*`` environment values that start with the input or
    output prefix, uploads the referenced files, runs the command through
    the selected runner, and downloads the outputs.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('-v', '--verbose', action='store_true', dest='verbose',
                        help='print commands as they are run')
    parser.add_argument('-n', '--dry-run', action='store_true', dest='dry_run',
                        help="print the commands that would have been run, but "
                             "don't actually run them")

    parser.add_argument('--remote-dir', required=True, metavar='PATH',
                        help='(required) a writable temporary path on the '
                             'remote machine')
    parser.add_argument('--input-prefix',
                        help='arguments matching this prefix will be uploaded')
    parser.add_argument('--output-prefix',
                        help='arguments matching this prefix will be both '
                             'uploaded and downloaded')
    parser.add_argument('--remote-input-prefix', default='input',
                        help='input arguments use this prefix on the remote '
                             'machine')
    parser.add_argument('--remote-output-prefix', default='output',
                        help='output arguments use this prefix on the remote '
                             'machine')

    parser.add_argument('-i', '--identity', dest='identity', metavar='FILE',
                        help='an SSH identity file (private key) to use')
    parser.add_argument('-F', '--config-file', dest='config_file', metavar='FILE',
                        help='an SSH configuration file')
    parser.add_argument('-o', '--ssh-option', action='append', default=[],
                        dest='ssh_options', metavar='OPTION',
                        help='extra SSH config options (man ssh_config)')
    parser.add_argument('--debug-as-local', action='store_true',
                        help='run commands locally instead of over SSH, for '
                             'debugging purposes. The "host" argument is '
                             'omitted.')
    parser.add_argument('--ignore-rsync-failure', action='store_true',
                        help='ignore rsync failures, for debugging.')

    parser.add_argument('host',
                        help='the host to connect to, in the form '
                             '[user@]host[:port]')
    parser.add_argument('command', nargs=argparse.REMAINDER,
                        help='the command to run', metavar='command...')
    args = parser.parse_args()

    # The output directory below is removed with `rm -rf`, so refuse to
    # operate on the filesystem root.  This is a real check rather than an
    # assert (asserts vanish under `python -O`), and it compares the
    # normalised path so spellings such as '/.' or '///' are caught too
    # (normpath keeps exactly two leading slashes, hence the '//' entry).
    if posixpath.normpath(args.remote_dir) in ('/', '//'):
        parser.error('--remote-dir must not be the filesystem root')

    if args.debug_as_local:
        runner = LocalCommandRunner()
        # With no host to parse, the first positional is really the start
        # of the command.
        args.command.insert(0, args.host)
        del args.host
    else:
        runner = RemoteCommandRunner(args.host,
                                     args.identity,
                                     args.ssh_options,
                                     args.config_file)
    runner.dry_run = args.dry_run
    runner.verbose = args.verbose or args.dry_run
    runner.ignore_rsync_failure = args.ignore_rsync_failure

    # Both processors share one path set so that files referenced from
    # either the arguments or the environment are transferred.
    path_set = RemotePathSet()
    remote_input_prefix = posixpath.normpath(args.remote_input_prefix)
    remote_output_prefix = posixpath.normpath(args.remote_output_prefix)

    argproc = ArgumentProcessor(args.input_prefix,
                                args.output_prefix,
                                args.remote_dir,
                                remote_input_prefix,
                                remote_output_prefix,
                                path_set,
                                args.command)

    argproc.process_args()

    envproc = EnvVarProcessor(args.input_prefix,
                              args.output_prefix,
                              args.remote_dir,
                              remote_input_prefix,
                              remote_output_prefix,
                              path_set,
                              collect_remote_env())

    envproc.process_env()

    input_dir = posixpath.join(args.remote_dir, args.remote_input_prefix)
    output_dir = posixpath.join(args.remote_dir, args.remote_output_prefix)

    # Clear out any previous remote output directory.
    if args.output_prefix:
        runner.run_remote(['/bin/rm', '-rf', output_dir])

    # Create the remote parent directory of every expected output.  Sorting
    # keeps the generated command stable from run to run.
    dirs = {posixpath.join(output_dir, posixpath.dirname(output))
            for output in path_set.outputs | path_set.nodir_outputs}
    runner.mkdirs_remote(sorted(dirs))

    # "nodir" names are relative to the parent of the prefix, so they are
    # transferred from/to that parent directory.
    if path_set.inputs:
        runner.send(args.input_prefix, input_dir, path_set.inputs)

    if path_set.nodir_inputs:
        runner.send(posixpath.dirname(args.input_prefix),
                    input_dir, path_set.nodir_inputs)

    if path_set.existing_outputs:
        runner.send(args.output_prefix, output_dir, path_set.existing_outputs)

    if path_set.existing_nodir_outputs:
        runner.send(posixpath.dirname(args.output_prefix),
                    output_dir, path_set.existing_nodir_outputs)

    runner.run_remote(argproc.args, envproc.env)

    if path_set.outputs:
        runner.fetch(args.output_prefix, output_dir, path_set.outputs)

    if path_set.nodir_outputs:
        runner.fetch(posixpath.dirname(args.output_prefix),
                     output_dir, path_set.nodir_outputs)

# Run the tool only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
