#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import distutils.util
import os.path
import subprocess
import sys
import xml.etree.ElementTree as ET

def xml_from_stdin():
    """Parse the XML document supplied on standard input.

    :returns: the document's root ``Element``.
    """
    return ET.parse(sys.stdin).getroot()

def getsize(dev):
    """Return the size of *dev* as printed by ``blockdev --getsz``.

    :param dev: path of the block device to query.
    :raises subprocess.CalledProcessError: if blockdev exits non-zero.
    """
    raw = subprocess.check_output(["/sbin/blockdev", "--getsz", dev])
    # int() accepts bytes and ignores the trailing newline.
    return int(raw)

def mkloop(source):
    """Attach *source* to the first free loop device and return its path.

    :param source: file to back the loop device.
    :returns: the loop device name printed by ``losetup --show``, with
        trailing whitespace removed.
    :raises subprocess.CalledProcessError: if losetup exits non-zero.
    """
    printed = subprocess.check_output(["/sbin/losetup", "--find", "--show", source])
    return printed.decode().rstrip()

def mkdisk(virtdisk, ns):
    """Assemble a device-mapper disk from a virtual MBR image and host devices.

    The device ``/dev/mapper/<name>`` (the ``name`` attribute of *virtdisk*)
    is a linear concatenation of:

    * the ``virtmbr`` image file, attached through a loop device, followed by
    * the ``source`` device of every ``papa:part`` child, in document order.

    If the ``virtmbr`` file does not exist yet, a zero-filled 1 MiB image is
    created. Once the mapping exists, sfdisk writes a partition table into it
    with one entry per part.

    :param virtdisk: ``papa:virtdisk`` element from the domain metadata.
    :param ns: prefix-to-URI map used to resolve ``papa:`` in child lookups.
    :raises subprocess.CalledProcessError: if losetup, blockdev or
        ``dmsetup create`` fail.

    Cleanup on failure:

    * Loop devices attached here are detached in every case.
    * A freshly created MBR image is removed again if the mapping could not
      be created. Otherwise the next run would treat the unpartitioned image
      as existing and skip sfdisk.
    """
    mapping = []    # dmsetup table lines: "<start> <length> linear <device> 0"
    partdata = []   # sfdisk input lines: "<start>,<size>,<type>"
    loopdevs = []
    pos = 0         # running offset, in the units returned by getsize()
    mapped = False
    virtmbr = virtdisk.get('virtmbr')

    creatembr = not os.path.isfile(virtmbr)
    if creatembr:
        with open(virtmbr, 'wb') as mbr:
            mbr.write(b"\0" * (1 << 20))

    try:
        loop = mkloop(virtmbr)
        loopdevs.append(loop)
        size = getsize(loop)
        mapping.append("{} {} linear {} 0".format(pos, size, loop))
        pos += size

        for part in virtdisk.iterfind('papa:part', ns):
            source = part.get('source')
            # NOTE(review): distutils was removed in Python 3.12 (PEP 632).
            # This strtobool call and the module-level import need a
            # replacement there.
            if distutils.util.strtobool(part.get('umount', 'no')):
                # Best effort: the umount exit status is not checked.
                subprocess.run(["umount", source])
            size = getsize(source)
            mapping.append("{} {} linear {} 0".format(pos, size, source))
            partdata.append("{},{},{}".format(pos, size, part.get('type', 7)))
            pos += size

        # Fail loudly: without the mapping there is nothing to partition.
        subprocess.run(
            ['/sbin/dmsetup', 'create', virtdisk.get('name')],
            input="\n".join(mapping).encode(),
            check=True,
        )
        mapped = True

        if creatembr:
            # The sfdisk exit status is not checked.
            subprocess.run(
                ['/sbin/sfdisk', '--wipe=never', '--no-reread',
                 '/dev/mapper/{}'.format(virtdisk.get('name'))],
                input="\n".join(partdata).encode(),
            )
    finally:
        # NOTE(review): detaching right after `dmsetup create` presumably
        # relies on the kernel deferring the detach (autoclear) while
        # device-mapper still holds the loop device. Confirm.
        for loop in loopdevs:
            subprocess.run(["/sbin/losetup", "-d", loop])
        if creatembr and not mapped:
            os.remove(virtmbr)

def deldisk(virtdisk, ns):
    """Remove the device-mapper disk for *virtdisk* and remount flagged parts.

    The steps are:

    1. Run ``dmsetup remove`` on the ``name`` attribute of *virtdisk*.
    2. For each ``papa:part`` child whose ``mount`` attribute is true
       (default ``'no'``), run ``mount`` on its ``source``.

    Exit statuses of both commands are not checked.

    :param virtdisk: ``papa:virtdisk`` element from the domain metadata.
    :param ns: prefix-to-URI map used to resolve ``papa:`` in child lookups.
    """
    subprocess.run(['/sbin/dmsetup', 'remove', virtdisk.get('name')])
    # Lazy, so mounts happen one by one in document order, exactly as a loop would.
    to_remount = (
        part.get('source')
        for part in virtdisk.iterfind('papa:part', ns)
        if distutils.util.strtobool(part.get('mount', 'no'))
    )
    for device in to_remount:
        subprocess.run(["mount", device])

def main():
    """Libvirt qemu hook entry point.

    Reads the domain XML from stdin, then dispatches on ``sys.argv[2:4]``:

    * ``prepare begin`` builds every ``papa:virtdisk`` found under
      ``metadata/papa:partpass`` with mkdisk.
    * ``release end`` tears them down with deldisk.
    * Any other phase does nothing.
    """
    domain = xml_from_stdin()
    namespaces = {'papa': 'http://www.unix-ag.uni-kl.de/~t_schmid/partpass'}
    phase = sys.argv[2:4]
    if phase == ["prepare", "begin"]:    # "/etc/libvirt/hooks/qemu guest_name prepare begin -"
        handler = mkdisk
    elif phase == ["release", "end"]:    # "/etc/libvirt/hooks/qemu guest_name release end -"
        handler = deldisk
    else:
        return
    for virtdisk in domain.iterfind('./metadata/papa:partpass/papa:virtdisk', namespaces):
        handler(virtdisk, namespaces)

# Run main() only when executed as a script (libvirt invokes the qemu hook
# this way); importing the module performs no actions beyond definitions.
if __name__ == "__main__":
    main()
