import os
import winreg
import argparse

# ETW provider name registered for each supported driver binary.
provider_dict = {
    "npu_level_zero_umd.dll": "Intel-NPU-LevelZero",
    "npu_d3d12_umd.dll": "Intel-NPU-D3D12",
    "npu_kmd.sys": "Intel-NPU-Kmd",
}

# Registry location (relative to HKEY_LOCAL_MACHINE) that holds one subkey per
# event publisher, named by the publisher GUID.
_WINEVT_PUBLISHERS = r"SOFTWARE\Microsoft\Windows\CurrentVersion\WINEVT\Publishers"

# Publisher GUID for each driver binary, copied verbatim (note the mixed casing).
_PUBLISHER_GUIDS = {
    "npu_level_zero_umd.dll": "{416f823f-2ce2-44b9-a1ba-7e98ba4cd4ba}",
    "npu_d3d12_umd.dll": "{11a83531-4ac9-4142-8d35-e474b6b3c597}",
    "npu_kmd.sys": "{B3B1AAB1-3C04-4B6D-A069-59547BC18233}",
}

# Full publisher key path for each driver binary. Insertion order is kept,
# because main() registers the binaries in this dict's iteration order.
key_path_dict = {
    binary: _WINEVT_PUBLISHERS + "\\" + guid
    for binary, guid in _PUBLISHER_GUIDS.items()
}

def create_reg_key(key_path):
    """Create (or open, if it already exists) ``HKEY_LOCAL_MACHINE\\<key_path>``.

    Failures (e.g. insufficient privileges) are printed rather than raised,
    so registration continues with the next key.
    """
    try:
        # The handle is only needed to force creation; the context manager
        # closes it straight away.
        with winreg.CreateKey(winreg.HKEY_LOCAL_MACHINE, key_path):
            pass
        print(f"HKEY_LOCAL_MACHINE\\{key_path} created successfully.")
    except Exception as err:
        print("Error:", err)

def set_key_value(key_path, values):
    """Write several values into an existing ``HKEY_LOCAL_MACHINE`` key.

    Args:
        key_path: Key path relative to HKEY_LOCAL_MACHINE; the key must exist.
        values: Iterable of ``(name, data, reg_type)`` tuples passed to
            ``winreg.SetValueEx``. Use ``""`` as the name to set the key's
            default (unnamed) value.

    Errors are printed rather than raised (best-effort registration). Values
    written before a failure are left in place.
    """
    try:
        # The context manager closes the handle even when SetValueEx raises;
        # previously an exception skipped CloseKey and leaked the handle.
        with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, key_path, 0,
                            winreg.KEY_WRITE) as key:
            for name, value, value_type in values:
                winreg.SetValueEx(key, name, 0, value_type, value)
    except Exception as e:
        print(f"An error occurred: {e}")

def register_file(file_name: str, file_path: str):
    """Register the ETW publisher for one driver binary in HKEY_LOCAL_MACHINE.

    Creates the publisher key from ``key_path_dict`` and enables it. Names
    the provider via the key's default value and points the resource,
    parameter and message file values at *file_path*. Then writes the
    ``ChannelReferences`` subtree:

    * an Operational channel (Id 0x10) for every provider, and
    * an Analytic channel (Id 0x11) for ``npu_level_zero_umd.dll`` only.

    Args:
        file_name: Binary name; must be a key of ``provider_dict`` and
            ``key_path_dict``.
        file_path: Full path of the binary, stored in the registry values.

    Raises:
        KeyError: If *file_name* is not one of the known binaries.

    Registry failures are printed by the helper functions rather than raised.
    """
    print(f"\nRegistering {file_path}")

    provider_name = provider_dict[file_name]
    publisher_key = key_path_dict[file_name]
    channels_key = publisher_key + r"\ChannelReferences"

    # (channel Id, channel suffix) in ChannelReferences index order.
    # Level Zero has two channel references (Operational/Analytic).
    channels = [(0x10, "Operational")]
    if file_name == "npu_level_zero_umd.dll":
        channels.append((0x11, "Analytic"))

    create_reg_key(publisher_key)
    set_key_value(publisher_key, [
        ("Enabled", 0x1, winreg.REG_DWORD),
        # The key's default value is addressed by the empty name in the API.
        # "(Default)" is only regedit's display label; passing it would
        # create a separate value literally named "(Default)".
        ("", provider_name, winreg.REG_SZ),
        ("ResourceFileName", file_path, winreg.REG_EXPAND_SZ),
        ("ParameterFileName", file_path, winreg.REG_EXPAND_SZ),
        ("MessageFileName", file_path, winreg.REG_EXPAND_SZ),
    ])

    create_reg_key(channels_key)
    set_key_value(channels_key, [
        ("Count", len(channels), winreg.REG_DWORD),
    ])

    for index, (channel_id, channel_suffix) in enumerate(channels):
        channel_key = f"{channels_key}\\{index}"
        create_reg_key(channel_key)
        set_key_value(channel_key, [
            ("Flags", 0x0, winreg.REG_DWORD),
            ("Id", channel_id, winreg.REG_DWORD),
            ("", f"{provider_name}/{channel_suffix}", winreg.REG_SZ),
        ])

def main():
    """Register every known driver binary found in the driver package.

    The package directory comes from the optional ``driver_path`` argument.
    When it is omitted, the script falls back to ``../../drivers/x64``
    relative to the script's own directory. Binaries missing from the package
    are reported and skipped.
    """
    arg_parser = argparse.ArgumentParser(description="Register etw to provided driver path dlls")
    arg_parser.add_argument(
        'driver_path',
        nargs='?',
        type=str,
        help='path to driver package containing dll/sys files (drivers/x64)',
    )
    package_dir = arg_parser.parse_args().driver_path

    if not package_dir:
        # Fall back to a path relative to this script's location.
        script_dir = os.path.dirname(__file__)
        package_dir = os.path.abspath(
            os.path.join(script_dir, "..", "..", "drivers", "x64"))
        print("No driver path provided.\n"
              f"Falling back to relative path {package_dir}\n")

    print(f"Driver package path: {package_dir}")

    for binary_name in key_path_dict:
        binary_path = os.path.join(package_dir, binary_name)
        if not os.path.exists(binary_path):
            print(f"ERROR: {binary_path} not found")
            continue
        register_file(binary_name, binary_path)

# Run the registration only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
