#!/usr/bin/env python3
"""Generate the C dispatch layer for one operation.

Usage::

    <this script> auto|manual OPERATION HOST < IMPLEMENTATIONS

Everything is read relative to the current directory:

* ``api`` -- lines ``crypto_<op>/<primitive>``, ``#define crypto_<op>_<NAME> ...``
  and one-line C prototypes ``<rettype> crypto_<op>_<name>(<args>);``.  Only
  the prototypes shape the output; the other two kinds are only shape-checked.
* stdin -- one ``operation/implementation/compiler`` entry per line; entries
  for other operations are ignored.
* ``compilerarch/<compiler>`` / ``compilerversion/<compiler>`` -- the
  architecture string (``default`` or a ``+``-separated list) and the
  compiler description recorded for each implementation.
* ``priority/<operation>-<implementation>`` (``auto`` only, optional) --
  whitespace-separated lines ``prio score host cpuid version machine
  compiler...``; a lower ``prio`` is preferred.

``auto`` picks one implementation per architecture and emits ``ifunc``
resolvers plus ``lib25519_<op>_implementation`` / ``lib25519_<op>_compiler``.
``manual`` emits ``lib25519_dispatch_*`` functions that select an
implementation by index, plus ``lib25519_numimpl_<op>``.  The generated C is
written to stdout.
"""

import os
import sys
import re

# (word index, bit) of each feature flag inside a priority-file cpuid dump,
# which cpuidsupports() parses as 32 consecutive 8-hex-digit words.
_CPUID_BITS = {
    'mmx': (18, 23),
    'sse': (18, 25),
    'sse2': (18, 26),
    'sse3': (17, 0),
    'ssse3': (17, 9),
    'sse41': (17, 19),
    'sse42': (17, 20),
    'osxsave': (17, 27),
    'avx': (17, 28),
    'bmi1': (20, 3),
    'avx2': (20, 5),
    'bmi2': (20, 8),
    'avx512f': (20, 16),
    'adx': (20, 19),
    'avx512ifma': (20, 21),
    'avx512vl': (20, 31),
    'xmmsaved': (27, 1),
    'ymmsaved': (27, 2),
}

# Flags required by both 'avx' and 'avx2', including the saved-state bits.
_AVX_BASELINE = ('avx', 'mmx', 'sse', 'sse2', 'sse3', 'ssse3', 'sse41',
                 'sse42', 'osxsave', 'xmmsaved', 'ymmsaved')

# For each architecture feature name: flags that must all be set.
_CPUID_REQUIREMENTS = {
    'adx': ('adx',),
    'avx': _AVX_BASELINE,
    'bmi2': ('bmi1', 'bmi2'),
    'avx2': ('avx2',) + _AVX_BASELINE,
    'avx512f': ('avx512f',),
    'avx512vl': ('avx512vl',),
    'avx512ifma': ('avx512ifma',),
}


def cstring(x):
    """Return *x* as a double-quoted C string literal.

    Backslashes, double quotes and newlines are escaped; every other
    character is copied unchanged.
    """
    return '"%s"' % x.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')


def sanitize(x):
    """Return *x* with every character outside [0-9a-z] replaced by '_'.

    Used to turn architecture and implementation names into fragments of C
    identifiers.
    """
    return ''.join(c if c in '0123456789abcdefghijklmnopqrstuvwxyz' else '_' for c in x)


def archkey(a):
    """Sort key for architecture names.

    Names with more '+'-separated parts sort first, ties are broken by name,
    and 'default' always sorts last.
    """
    if a == 'default':
        return 1, a  # put default last
    return -a.count('+'), a


def arch_covers(a, ica):
    """Return whether architecture *a* includes every feature of *ica*.

    Only the parts after the first '+' are compared, so 'default' (no '+')
    requires nothing and provides nothing.
    """
    have = a.split('+')[1:]
    return all(part in have for part in ica.split('+')[1:])


def cpuidsupports(cpuid, a):
    """Return whether the CPU described by *cpuid* can run architecture *a*.

    *cpuid* is a hex string holding (at least) 32 words of 8 hex digits
    each; shorter input raises ValueError from int().  Each '+'-part of *a*
    after the first must be a key of _CPUID_REQUIREMENTS; parts are checked
    in order, and an unknown part raises ValueError unless an earlier part
    already failed.
    """
    words = [int('0x' + cpuid[8 * j:8 * j + 8], 16) for j in range(32)]
    have = {name: bool(words[w] & (1 << b)) for name, (w, b) in _CPUID_BITS.items()}
    for apart in a.split('+')[1:]:
        if apart not in _CPUID_REQUIREMENTS:
            raise ValueError('cpuidsupports does not understand %s' % apart)
        if not all(have[flag] for flag in _CPUID_REQUIREMENTS[apart]):
            return False
    return True


def read_prototypes(path='api'):
    """Parse the api file and return {operation: [(rettype, fun, args), ...]}.

    *fun* is the function name split on '_' (its first element is asserted
    to be 'crypto') and *args* is the raw text between the parentheses.
    ``crypto_<op>/<primitive>`` and ``#define crypto_<op>_<NAME>`` lines are
    only shape-checked with asserts.
    """
    prototypes = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line.startswith('crypto_'):
                assert len(line.split('/')) == 2
                continue
            if line.startswith('#define '):
                name = line.split(' ')[1].split('_')
                assert len(name) == 3
                assert name[0] == 'crypto'
                continue
            if line.endswith(');'):
                head, args = line[:-2].split('(')
                rettype, fun = head.split()
                fun = fun.split('_')
                op = fun[1]
                assert fun[0] == 'crypto'
                prototypes.setdefault(op, []).append((rettype, fun, args))
    return prototypes


def read_impls(lines, o):
    """Return [implementation, compiler] lists for operation *o* from *lines*.

    Each line is 'operation/...' split on '/'; lines for other operations
    are skipped.
    """
    impls = []
    for line in lines:
        fields = line.strip().split('/')
        if fields[0] != o:
            continue
        impls.append(fields[1:])
    return impls


def read_prioritydata(o, allimpls):
    """Load measurements from priority/<o>-<impl> for every impl that has one.

    Returns tuples (impl, prio, score, host, cpuid, machine, compiler) with
    prio as a float.  Lines with fewer than 7 fields are ignored; fields
    from the 7th on are joined with spaces to form the compiler description.
    """
    prioritydata = []
    for i in allimpls:
        priorityfn = 'priority/%s-%s' % (o, i)
        if not os.path.exists(priorityfn):
            continue
        with open(priorityfn) as f:
            for line in f:
                fields = line.split()
                if len(fields) < 7:
                    continue
                prio, score, priohost, cpuid, _version, machine = fields[:6]
                c = ' '.join(fields[6:])
                prioritydata.append((i, float(prio), score, priohost, cpuid, machine, c))
    return prioritydata


def selectic(a, aexclude, impls, icarch, iccompiler, prioritydata, host):
    """Choose the (implementation, compiler) pair to use on architecture *a*.

    Candidates are the pairs whose compiler architecture is covered by *a*
    (asserted non-empty).  If some measurement comes from *host* on a CPU
    supporting *a* but none of the already-handled architectures in
    *aexclude*, only such measurements are used; otherwise all are.  Each
    measurement adds its prio to the matching implementation, weighted up
    for the same host, a CPU supporting *a*, a CPU supporting none of
    *aexclude*, a longer common prefix with the compiler description, and
    an identical compiler description.  The lowest weighted mean wins; if no
    candidate has measurements, the first candidate is returned.

    Prints C comments describing the reasoning.
    """
    if len(aexclude) > 0:
        print('/* considering other machines supporting %s */' % a)
    else:
        print('/* considering machines supporting %s */' % a)

    # requirement: icarch[i,c] is a subset of a
    compatibleimpls = [(i, c) for i, c in impls if arch_covers(a, icarch[i, c])]
    assert len(compatibleimpls) > 0

    # desideratum: good performance based on prioritydata
    # (short-circuit order kept: cpuid is only parsed for entries from host)
    directmatches = any(
        priohost == host
        and cpuidsupports(cpuid, a)
        and all(not cpuidsupports(cpuid, b) for b in aexclude)
        for _i, _prio, _score, priohost, cpuid, _machine, _c in prioritydata
    )
    if not directmatches:
        print('/* no direct matches, so extrapolating from all machines */')

    totalprio = {ic: 0 for ic in compatibleimpls}
    totalweight = {ic: 0 for ic in compatibleimpls}
    for prioi, prio, _score, priohost, cpuid, _machine, prioc in prioritydata:
        if directmatches:
            if priohost != host:
                continue
            if any(cpuidsupports(cpuid, b) for b in aexclude):
                continue
            if not cpuidsupports(cpuid, a):
                continue
        for i, c in compatibleimpls:
            if i != prioi:
                continue
            # XXX: use more serious machine learning here
            weight = 1.0
            if priohost == host:
                weight *= 10
            if cpuidsupports(cpuid, a):
                weight *= 10
            if all(not cpuidsupports(cpuid, b) for b in aexclude):
                weight *= 10
            weight *= 1 + len(os.path.commonprefix([iccompiler[i, c], prioc]))
            if iccompiler[i, c] == prioc:
                weight *= 10
            totalprio[i, c] += prio * weight
            totalweight[i, c] += weight

    # note that implementations without priority data are excluded from ranking
    ranking = sorted(
        (totalprio[i, c] / totalweight[i, c], i, c)
        for i, c in compatibleimpls
        if totalweight[i, c] > 0
    )
    for prio, i, c in ranking:
        print('/* priority %s for %s %s */' % (prio, i, c))
    if not ranking:
        return compatibleimpls[0]
    return ranking[0][1:]


def printfun_auto(which, o, todo, iccompiler, fun=None, rettype=None, args=None):
    """Print an auto-mode C function returning the choice for the running CPU.

    *which* is 'resolver' (the ifunc resolver for prototype
    *rettype*/*fun*/*args*, followed by the ifunc-aliased declaration),
    'implementation' or 'compiler'; anything else raises ValueError.
    *todo* lists (arch, implementation, compiler) decisions; each non-default
    arch is guarded by lib25519_supports_<arch>() and output stops after the
    'default' entry.

    NOTE(review): if *todo* has no 'default' entry the generated function has
    no unconditional return -- presumably a default implementation always
    exists; confirm.
    """
    if which == 'resolver':
        shortfun = '_'.join(fun[1:])
        print('void *lib25519_auto_%s(void)' % shortfun)
    elif which == 'implementation':
        print('const char *lib25519_%s_implementation(void)' % o)
    elif which == 'compiler':
        print('const char *lib25519_%s_compiler(void)' % o)
    else:
        raise ValueError('unknown printfun %s' % which)
    print('{')
    for a, i, c in todo:
        cond = '' if a == 'default' else 'if (lib25519_supports_%s()) ' % sanitize(a)
        if which == 'resolver':
            result = 'lib25519_%s_%s_%s_%s' % (o, sanitize(i), c, shortfun)
        elif which == 'implementation':
            result = cstring(i)
        else:
            result = cstring(iccompiler[i, c])
        print(' %sreturn %s;' % (cond, result))
        if a == 'default':
            break
    print('}')
    if which == 'resolver':
        print('')
        print('%s lib25519_%s(%s) __attribute__((visibility("default"))) __attribute__((ifunc("lib25519_auto_%s")));' % (rettype, shortfun, args, shortfun))


def print_manual_dispatch(header, allarches, entries, fallback):
    """Print a manual-mode C function returning the impl-th usable entry.

    *entries* is a list of (arch, C expression) pairs in enumeration order.
    Non-default entries only count when lib25519_supports_<arch>() is true
    at run time.  A negative or out-of-range impl returns *fallback*.
    """
    print(header)
    print('{')
    for a in allarches:
        if a == 'default':
            continue
        a_csymbol = sanitize(a)
        print(' int supports_%s = lib25519_supports_%s();' % (a_csymbol, a_csymbol))
    print(' if (impl >= 0) {')
    for a, value in entries:
        guard = '' if a == 'default' else 'if (supports_%s) ' % sanitize(a)
        print(' %sif (!impl--) return %s;' % (guard, value))
    print(' }')
    print(' return %s;' % fallback)
    print('}')


def print_numimpl(o, allarches, impls, icarch):
    """Print lib25519_numimpl_<o>(): the number of implementations usable here.

    NOTE(review): multiplies by lib25519_supports_<arch>(), so it assumes
    those return 0 or 1, while the dispatch functions only test them for
    truth -- confirm.
    """
    print('long long lib25519_numimpl_%s(void)' % o)
    print('{')
    terms = ['%d' % sum(1 for i, c in impls if icarch[i, c] == 'default')]
    for a in allarches:
        if a == 'default':
            continue
        a_csymbol = sanitize(a)
        print(' long long supports_%s = lib25519_supports_%s();' % (a_csymbol, a_csymbol))
        count = sum(1 for i, c in impls if icarch[i, c] == a)
        terms.append('supports_%s*%d' % (a_csymbol, count))
    print(' return %s;' % '+'.join(terms))
    print('}')


def main(argv=None, stdin=None):
    """Run the generator; see the module docstring for inputs and outputs.

    *argv* defaults to sys.argv and *stdin* to sys.stdin.

    Raises:
        ValueError: if the goal argument is neither 'auto' nor 'manual'.
    """
    argv = sys.argv if argv is None else argv
    stdin = sys.stdin if stdin is None else stdin

    prototypes = read_prototypes('api')

    goal = argv[1]
    if goal not in ('auto', 'manual'):
        raise ValueError('goal must be auto or manual, not %r' % goal)
    o = argv[2]
    host = argv[3]

    impls = read_impls(stdin, o)

    icarch = {}
    iccompiler = {}
    for i, c in impls:
        with open('compilerarch/%s' % c) as f:
            icarch[i, c] = f.read().strip()
        with open('compilerversion/%s' % c) as f:
            iccompiler[i, c] = f.read().strip()

    allarches = sorted({icarch[i, c] for i, c in impls}, key=archkey)

    # Auto mode: choose one implementation per architecture, most specific
    # architecture first, excluding machines already covered by earlier ones.
    todo = []
    usedimpls = set()
    if goal == 'auto':
        allimpls = sorted({i for i, c in impls})
        prioritydata = read_prioritydata(o, allimpls)
        handledarches = set()
        for a in allarches:
            i, c = selectic(a, handledarches, impls, icarch, iccompiler, prioritydata, host)
            usedimpls.add((i, c))
            todo.append((a, i, c))
            handledarches.add(a)
        for a, i, c in todo:
            print('/* decision: for %s use %s %s */' % (a, i, c))
        print('')

    if goal == 'auto':
        print('extern const char *lib25519_%s_implementation(void) __attribute__((visibility("default")));' % o)
        print('extern const char *lib25519_%s_compiler(void) __attribute__((visibility("default")));' % o)
    else:
        print('extern const char *lib25519_%s_implementation(void);' % o)
        print('extern const char *lib25519_%s_compiler(void);' % o)
        print('extern const char *lib25519_dispatch_%s_implementation(long long) __attribute__((visibility("default")));' % o)
        print('extern const char *lib25519_dispatch_%s_compiler(long long) __attribute__((visibility("default")));' % o)
        print('extern long long lib25519_numimpl_%s(void) __attribute__((visibility("default")));' % o)
    for a in allarches:
        if a == 'default':
            continue
        print('extern int lib25519_supports_%s(void);' % sanitize(a))
    if len(allarches) > 1:
        print('')

    for rettype, fun, args in prototypes[o]:
        shortfun = '_'.join(fun[1:])
        if goal == 'auto':
            print('extern %s lib25519_%s(%s) __attribute__((visibility("default")));' % (rettype, shortfun, args))
        else:
            print('extern %s lib25519_%s(%s);' % (rettype, shortfun, args))
            print('extern %s (*lib25519_dispatch_%s(long long))(%s) __attribute__((visibility("default")));' % (rettype, shortfun, args))
        print('')

        for i, c in impls:
            # auto mode only links the implementations that were selected
            if goal == 'auto' and (i, c) not in usedimpls:
                continue
            print('extern %s lib25519_%s_%s_%s_%s(%s) __attribute__((visibility("default")));' % (rettype, o, sanitize(i), c, shortfun, args))
        print('')

        if goal == 'auto':
            printfun_auto('resolver', o, todo, iccompiler, fun, rettype, args)
        else:
            print_manual_dispatch(
                '%s (*lib25519_dispatch_%s(long long impl))(%s)' % (rettype, shortfun, args),
                allarches,
                [(icarch[i, c], 'lib25519_%s_%s_%s_%s' % (o, sanitize(i), c, shortfun)) for i, c in impls],
                'lib25519_%s' % shortfun,
            )
        print('')

    if goal == 'auto':
        printfun_auto('implementation', o, todo, iccompiler)
        print('')
        printfun_auto('compiler', o, todo, iccompiler)
    else:
        print_manual_dispatch(
            'const char *lib25519_dispatch_%s_implementation(long long impl)' % o,
            allarches,
            [(icarch[i, c], cstring(i)) for i, c in impls],
            'lib25519_%s_implementation()' % o,
        )
        print('')
        print_manual_dispatch(
            'const char *lib25519_dispatch_%s_compiler(long long impl)' % o,
            allarches,
            [(icarch[i, c], cstring(iccompiler[i, c])) for i, c in impls],
            'lib25519_%s_compiler()' % o,
        )
        print('')
        print_numimpl(o, allarches, impls, icarch)


if __name__ == '__main__':
    main()