# Author: Gerald Hoehn, gerald@monstrous-moonshine.de, October 2023.

from mmgroup import *
from concurrent.futures import ProcessPoolExecutor
from classes import li


# Modulus for every trace computation in this module; traces are only ever
# compared after reduction modulo p.
p = 255
# mmgroup vector constructor for the representation modulo p; partial_trace
# uses it to build the unit vectors ('E', i).
V = MMV(p)

# Element orders for which the order alone identifies the class
# ("orders where there is only one class").
single = {
    1, 11, 17, 19, 23, 25, 27, 29, 31, 34, 38, 41, 44, 45, 47, 48, 50, 51,
    54, 55, 57, 59, 62, 68, 69, 71, 87, 88, 92, 93, 94, 95, 104, 105, 110,
    119,
}

# Orders for which con_class also needs the trace of g**2.
# can remove 12, 30, 60, for 32, 84 need only g^2
double = {8, 24, 28, 32, 42, 84}
# Orders for which con_class also needs the trace of g**3.
triple = {12, 24, 30, 60}


# Lookup tables derived from the class list ``li``, indexed by class number k:
#   signature[k] == (li[k][1], li[k][2], li[k][3], li[k][4]); con_class looks
#                   up (order, tr(g), tr(g**2), tr(g**3)) mod p in this list.
#   orders[k]    == li[k][1], the first component of signature[k].
signature = [(entry[1], entry[2], entry[3], entry[4]) for entry in li]
orders = [sig[0] for sig in signature]


def partial_trace(start, end, g):
    """Return the sum of the diagonal entries ``start`` .. ``end - 1`` of ``g``.

    For each index ``i`` the unit vector ``V([('E', i)])`` is multiplied by
    ``g``, and the ``i``-th coordinate of the image (``['E']`` view) is read
    off. Summed over all 196884 indices, this gives the trace that
    ``trace_parallel`` assembles from chunks.

    Args:
        start: first index (inclusive).
        end: last index (exclusive).
        g: group element acting on vectors built by ``V``.

    Returns:
        The unreduced sum as a Python ``int``.
    """
    total = 0
    for i in range(start, end):
        image = V([('E', i)]) * g
        # Coerce each entry to a Python int. The entries may be fixed-width
        # NumPy integers (NOTE(review): confirm dtype). Summing those can
        # silently wrap, and since 256 = 1 (mod 255) a wrap would corrupt
        # the result modulo p.
        total += int(image['E'][i])
    return total


def trace_parallel(g, n_threads):
    chunk_size = 196884 // n_threads
    futures = []

    with ProcessPoolExecutor(max_workers=n_threads) as executor:
        for i in range(0, n_threads):
            start_index = i * chunk_size
            end_index = start_index + chunk_size if i < n_threads - 1 else 196884
            futures.append(executor.submit(partial_trace, start_index, end_index, g))

        return sum(future.result() for future in futures)


def trace(g, n_threads=28):
    """Return the trace of ``g`` on the 196884-dimensional space, mod ``p``.

    Three strategies are tried in order:

    1. If ``g.in_G_x0()`` holds, use ``g.chi_G_x0()[0] + 1``.
    2. Otherwise call ``g.chi_powers(1, 100)``. If its ``chi[1]`` entry is
       not None, use ``chi[1] + 1`` (and print a notice).
    3. Otherwise compute the trace directly with ``trace_parallel``.

    The ``+ 1`` presumably accounts for a trivial summand
    (196884 = 1 + 196883). NOTE(review): confirm against the mmgroup
    documentation of ``chi_G_x0`` / ``chi_powers``.

    Args:
        g: group element (presumably an mmgroup ``MM`` instance).
        n_threads: worker processes for the brute-force fallback. Defaults
            to 28, the previously hard-coded value.

    Returns:
        The trace reduced into ``range(p)``.
    """
    if g.in_G_x0():
        tr = g.chi_G_x0()[0] + 1
    else:
        _, chi, _ = g.chi_powers(1, 100)
        if chi[1] is not None:
            tr = chi[1] + 1
            print('use of chi_power')
        else:
            tr = trace_parallel(g, n_threads)
    return tr % p


def con_class(g):
    """Return the index in ``li`` of the class containing ``g``.

    If the order of ``g`` is in ``single``, the order alone decides and the
    first matching entry of ``orders`` is returned. Otherwise the key
    ``(order, tr(g), tr(g**2), tr(g**3))`` mod ``p`` is looked up in
    ``signature``. The squared and cubed traces are only computed for orders
    in ``double`` and ``triple`` respectively; otherwise they are 0.

    Raises:
        ValueError: if no entry matches (propagated from ``list.index``).
    """
    order = g.order()
    if order in single:
        return orders.index(order)

    # Tuple items are evaluated left to right, so the trace calls happen in
    # the same order as before: g, then g**2, then g**3.
    key = (
        order,
        trace(g) % p,
        (trace(g**2) if order in double else 0) % p,
        (trace(g**3) if order in triple else 0) % p,
    )
    return signature.index(key)


def class_name(i):
    """Return the name stored for class ``i`` (field 5 of ``li[i]``)."""
    entry = li[i]
    return entry[5]

def class_rep(i):
    """Return a fresh ``MM`` element representing class ``i``.

    The element is constructed from field 0 of ``li[i]``.
    """
    word = li[i][0]
    return MM(word)


if __name__ == '__main__':
    n_classes = 171

    # Check 1: each stored representative must be classified into its own
    # class.
    for k in range(n_classes):
        rep = class_rep(k)
        ok = k == con_class(rep)
        print(k + 1, ' : ', class_name(k), ok, ' , ', signature[k], li[k][5],
              rep.in_G_x0())

    # Check 2: the same must hold after conjugating every representative by
    # one element created with MM('r', 'M').
    conjugator = MM('r', 'M')
    for k in range(n_classes):
        conj = class_rep(k) ** conjugator
        ok = k == con_class(conj)
        print(k + 1, ' : ', class_name(k), ok, ' , ', signature[k],
              conj.in_G_x0())

'''    for n in range(171):
        print('----', n)
        for _ in range(2):
            h = MM('r','M')
            g=class_rep(n+1)**h
            order, chi, f = g.chi_powers(1,100)
            if chi[1] is None:
                print('fail')
'''
