#! /usr/bin/env python

"""N-bit Feistel algorithm.

This is an implementation of a Feistel cipher with an arbitrary (even)
block length in bits.  It encrypts and decrypts (long) integers.  The
key is a string.  The round function F() is derived with SHA-1 from the
other half of the block (as usual), the round number, and the key string.

There is also a second, modular algorithm that permutes the range
[0 .. n*n-1], where n can be any positive integer.

Discussion on this: <URL:http://groups-beta.google.com/group/sci.crypt/browse_frm/thread/cd95584c9440be08/4cc0a5fd3dd179ba>

Copyright 2005 Antti Louko <alo@louko.com>
"""

__version__ = '0.4'  # module version string

import array
import getopt
import os
import sha
import struct
import sys

# Optional accelerator: enable psyco if it can be imported and started.
# This is best effort -- any failure here only means running without it.
# psyco_enabled records whether the accelerator is actually active.
psyco_enabled = False
try:
    import psyco
    psyco.full()
    psyco_enabled = True
except Exception:
    # Narrower than a bare except, but still deliberately best effort.
    pass

def usage(utyp, *msg):
    """Print a usage line (and optional error details) to stderr, then exit.

    utyp -- process exit status: 0 for an explicit --help request,
            non-zero for command-line errors.
    msg  -- optional error details; each item is str()-ed and joined
            with spaces on an "Error:" line.
    """
    prog = os.path.split(sys.argv[0])[1]
    sys.stderr.write('Usage: %s [-h] [-v] [--bits=N] [--key=KEY] [-n N]\n'
                     % prog)
    if msg:
        sys.stderr.write('Error: %s\n' % ' '.join([str(m) for m in msg]))
    # Honour the requested status (previously this always exited with 1,
    # so even --help reported failure).
    sys.exit(utyp)

def shas(s):
    """Return the raw binary SHA-1 digest of the string s."""
    hasher = sha.new()
    hasher.update(s)
    return hasher.digest()

def digest2l(d):
    """Split a binary digest into unsigned 32-bit integers.

    A 20-byte SHA-1 digest yields five words.  The words are decoded with
    struct as explicit little-endian 32-bit values, so the result no
    longer depends on the platform.  The former array.array('L', d) used
    the native C unsigned long, which is 8 bytes on most 64-bit builds:
    a 20-byte digest then failed to convert at all.  It also used native
    byte order.  Little-endian reproduces what the old code produced on
    32-bit little-endian machines.

    Returns an array.array('L') of the words, the same container type as
    before (callers index it and pop() from it).
    Raises ValueError if len(d) is not a multiple of 4.
    """
    nwords, extra = divmod(len(d), 4)
    if extra:
        raise ValueError('digest length %d is not a multiple of 4' % len(d))
    return array.array('L', list(struct.unpack('<%dI' % nwords, d)))

class Ffun:
    def __init__(self,bits,key,m=0):
        self.bits = bits
        self.mask = (1L << bits) - 1
        self.key = key
        self.m = m
    def fun(self,r,i):
        if self.bits <= 32:
            # We take just the first 32-bit element
            res = self.mask & digest2l(shas('%02x:%02x:%s' % (r,i,self.key)))[0]
            if self.m:
                res = res % self.m
            return res
        else:
            # We have to combine 32-bit elements into a self.bits-length integer
            bits = self.bits
            res = 0L
            c = 0
            a = []
            while bits > 0:
                if not a:
                    a = digest2l(shas('%02x:%02x:%02x:%s' % (r,i,c,self.key)))
                    c += 1
                x = a.pop(0)
                res = (res << 32) | x
                bits -= 32
            res &= self.mask
            if self.m:
                res = res % self.m
            return res
        
class Feistel:
    """Balanced Feistel cipher on `bits`-bit non-negative integers.

    The block is split into a high half l and a low half r of bits/2
    bits each.  Each round maps (l, r) -> (r, l ^ F(r, i)), where F is a
    keyed Ffun.  Inputs are first masked to `bits` bits.
    """

    def __init__(self, key, bits=64, rounds=7):
        # Validate first.  ValueError is a subclass of Exception, so
        # callers catching the old generic Exception still work.
        if bits % 2:
            raise ValueError('bits must be even')
        self.key = key
        self.bits = bits
        self.rounds = rounds
        self.dmask = (1 << bits) - 1
        self.bits2 = bits >> 1
        self.dmask2 = (1 << self.bits2) - 1
        self.fun = Ffun(self.bits2, key).fun

    def round2(self, l0, r0, fun, i):
        """One Feistel round: (l0, r0) -> (r0, l0 ^ fun(r0, i))."""
        return (r0, l0 ^ fun(r0, i))

    def _crypt(self, d, round_numbers):
        """Run the network on d, applying rounds in the given order."""
        d &= self.dmask
        l = d >> self.bits2
        r = d & self.dmask2
        for i in round_numbers:
            l, r = self.round2(l, r, self.fun, i)
        # The final halves are written back swapped (l low, r high).
        # That is what lets decrypt() reuse the same rounds in reverse.
        return l | (r << self.bits2)

    def encrypt(self, d):
        """Encrypt the integer d (masked to `bits` bits)."""
        return self._crypt(d, range(self.rounds))

    def decrypt(self, d):
        """Invert encrypt(): same rounds, applied in reverse order."""
        return self._crypt(d, reversed(range(self.rounds)))
        
# Leading zero bits of a single hex digit, keyed by the digit character.
nbits0 = {
    '0': 4, '1': 3, '2': 2, '3': 2,
    '4': 1, '5': 1, '6': 1, '7': 1,
    '8': 0, '9': 0, 'a': 0, 'b': 0,
    'c': 0, 'd': 0, 'e': 0, 'f': 0,
}


def nbits(n):
    """Return the number of significant bits in n (0 for n == 0)."""
    hexdigits = '%x' % n
    leading_zeros = nbits0[hexdigits[0]]
    return 4 * len(hexdigits) - leading_zeros

class FeistelModn2:
    """A modular-addition Feistel cipher.

    It creates a keyed permutation of [0 .. n**2-1].  An input d is read
    as two base-n digits, l = d // n and r = d % n.  Each round maps
    (l, r) -> (r, (l + F(r, i)) mod n).  The result is packed as l + r*n
    (halves swapped), which lets decrypt() run the rounds in reverse.
    """

    def __init__(self, key, n, rounds=7):
        # Reject n < 1 up front.  Otherwise n == 0 failed later with
        # ZeroDivisionError, and a negative n failed inside nbits().
        if n < 1:
            raise ValueError('n must be a positive integer')
        self.key = key
        self.n = n
        # Hash output width: at least 4 bits wider than n, and at least 32,
        # presumably so that reducing it modulo n stays close to uniform.
        self.bits2 = max(nbits(n) + 4, 32)
        self.rounds = rounds
        self.dmod = n * n
        self.fun = Ffun(self.bits2, key, n).fun

    def encrypt(self, d):
        """Map d (taken modulo n**2) to its image under the permutation."""
        n = self.n
        l, r = divmod(d % self.dmod, n)
        for i in range(self.rounds):
            l, r = r, (l + self.fun(r, i)) % n
        return l + r * n

    def decrypt(self, d):
        """Invert encrypt() for d taken modulo n**2."""
        n = self.n
        l, r = divmod(d % self.dmod, n)
        for i in reversed(range(self.rounds)):
            # (l - f) % n equals (l - (f % n)) % n for integers, so no
            # separate reduction of the round-function output is needed.
            l, r = r, (l - self.fun(r, i)) % n
        return l + r * n

class Global:
    """Program settings plus the self-test / demonstration driver."""

    def __init__(self):
        # Defaults; main() overrides them from the command line.
        self.vflag = 0      # verbosity level (count of -v flags)
        self.bits = 10      # block size for the Feistel round-trip check
        self.key = 'foo1'   # key string used by both ciphers
        self.n = 9          # the modular demo permutes [0 .. n*n-1]

    def doit(self, args):
        """Round-trip every block through Feistel, then print FeistelModn2.

        args (leftover command-line arguments) is not used.
        """
        out = sys.stdout

        # Part 1: exhaustive encrypt/decrypt check over all 2**bits blocks.
        total = 1 << self.bits
        hexfmt = '%%0%dx' % ((self.bits + 3) // 4)
        rowfmt = ' '.join([hexfmt, hexfmt, hexfmt])
        cipher = Feistel(self.key, self.bits)
        passed = 0
        failed = 0
        for plain in xrange(total):
            enc = cipher.encrypt(plain)
            dec = cipher.decrypt(enc)
            if dec == plain:
                if self.vflag:
                    out.write('    %s\n' % (rowfmt % (plain, enc, dec)))
                passed += 1
            else:
                failed += 1
                out.write('!!! %s\n' % (rowfmt % (plain, enc, dec)))
        out.write('%d ok %d failed\n' % (passed, failed))

        # Part 2: print the modular permutation of [0 .. n*n-1].
        out.write('Demonstrating the modular feistel\n')
        n = self.n
        width = len('%d' % (n * n))
        linefmt = '%%%dd --> %%%dd --> %%%dd\n' % (width, width, width)
        perm = FeistelModn2(self.key, n)
        i = 0
        while True:
            enc = perm.encrypt(i)
            dec = perm.decrypt(enc)
            out.write(linefmt % (i, enc, dec))
            i += 1
            # Tested after the body, so index 0 is always processed.
            if i >= n * n:
                break

def main(argv=None):
    """Command-line entry point: parse options, then run Global.doit().

    argv -- full argument vector with the program name first (like
            sys.argv); defaults to sys.argv when omitted.  It is now
            actually honoured; previously sys.argv was always parsed.

    Options: -h/--help, -v/--verbose (repeatable), --bits=N, --key=KEY,
    -n N/--n=N.  Unknown options or non-integer values exit via usage(1).
    """
    if argv is None:
        argv = sys.argv
    gp = Global()
    try:
        opts, args = getopt.getopt(argv[1:],
                                   'hvf:c:n:',
                                   ['help',
                                    'verbose',
                                    'bits=',
                                    'key=',
                                    'n=',
                                    ])
    except getopt.error, msg:
        usage(1, msg)

    # NOTE(review): -f and -c are accepted by getopt (each takes an
    # argument) but are never acted on below -- confirm whether they are
    # still needed.
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            usage(0)
        elif opt in ('-v', '--verbose'):
            gp.vflag = gp.vflag + 1
        elif opt in ('--key',):
            gp.key = arg
        elif opt in ('--bits', '-n', '--n'):
            # Report a bad number as a usage error instead of a traceback.
            try:
                value = int(arg)
            except ValueError:
                usage(1, 'option %s requires an integer, got %r' % (opt, arg))
            if opt == '--bits':
                gp.bits = value
            else:
                gp.n = value

    gp.doit(args)
        
if __name__ == '__main__':
    # Invoked as a script: hand the real argument vector to main().
    main(argv=sys.argv)
