from ..libmp.backend import xrange
from .functions import defun, defun_wrapped


@defun
def gammaprod(ctx, a, b, _infsign=False):
    """Return prod(gamma(x) for x in a) / prod(gamma(x) for x in b).

    Arguments for which ``ctx.isnpint`` is true are treated as poles of
    the gamma function and handled by counting:

    * fewer poles in ``a`` than in ``b``: the result is ``ctx.zero``;
    * more poles in ``a`` than in ``b``: the result is ``ctx.inf``, or a
      signed infinity when ``_infsign`` is true;
    * equal numbers: the poles are cancelled pairwise using the limit
      formula quoted in the body.

    The finite product is accumulated with 15 extra bits of working
    precision, and the original precision is always restored.
    """
    a = [ctx.convert(x) for x in a]
    b = [ctx.convert(x) for x in b]
    poles_num = []
    poles_den = []
    regular_num = []
    regular_den = []
    for x in a:
        if ctx.isnpint(x):
            poles_num.append(x)
        else:
            regular_num.append(x)
    for x in b:
        if ctx.isnpint(x):
            poles_den.append(x)
        else:
            regular_den.append(x)
    # One more pole in numerator or denominator gives 0 or inf
    if len(poles_num) < len(poles_den):
        return ctx.zero
    if len(poles_num) > len(poles_den):
        if not _infsign:
            return ctx.inf
        # Get correct sign of infinity for x+h, h -> 0 from above
        # XXX: hack, this should be done properly
        # Nudge every pole off the singularity: scale nonzero poles by
        # (1 + eps), and shift a pole at zero by eps.
        a = [x*(1+ctx.eps) if x else x+ctx.eps for x in poles_num]
        b = [x*(1+ctx.eps) if x else x+ctx.eps for x in poles_den]
        return ctx.sign(ctx.gammaprod(a+regular_num, b+regular_den)) * ctx.inf
    # All poles cancel
    # lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)
    p = ctx.one
    orig = ctx.prec
    try:
        ctx.prec = orig + 15
        while poles_num:
            i = poles_num.pop()
            j = poles_den.pop()
            p *= (-1)**(i+j) * ctx.gamma(1-j) / ctx.gamma(1-i)
        for x in regular_num:
            p *= ctx.gamma(x)
        for x in regular_den:
            p /= ctx.gamma(x)
    finally:
        ctx.prec = orig
    # Unary plus is applied after the original precision is restored.
    return +p


@defun
def beta(ctx, x, y):
    """Return gamma(x)*gamma(y)/gamma(x+y), evaluated through gammaprod.

    Infinite arguments are handled up front. The arguments are swapped if
    necessary so that ``x`` is the infinite one. When ``x`` is ``+inf``
    and ``y`` has no imaginary part the result is:

    * ``nan`` if ``y`` is ``-inf``;
    * zero if ``y > 0``;
    * ``nan`` if ``y`` is an integer (only reached for ``y <= 0``);
    * ``sign(gamma(y)) * inf`` if ``y < 0``.

    Every other case with an infinite argument gives ``nan``.
    """
    x = ctx.convert(x)
    y = ctx.convert(y)
    if ctx.isinf(y):
        x, y = y, x
    if ctx.isinf(x):
        if x == ctx.inf and not ctx._im(y):
            if y == ctx.ninf:
                return ctx.nan
            if y > 0:
                return ctx.zero
            if ctx.isint(y):
                return ctx.nan
            if y < 0:
                return ctx.sign(ctx.gamma(y)) * ctx.inf
        return ctx.nan
    # The sum is formed at twice the working precision before being passed
    # to gammaprod.
    xy = ctx.fadd(x, y, prec=2*ctx.prec)
    return ctx.gammaprod([x, y], [xy])


@defun
def binomial(ctx, n, k):
    """Return gamma(n+1) / (gamma(k+1)*gamma(n-k+1)) through gammaprod.

    The shifted arguments are formed at twice the working precision.
    """
    n1 = ctx.fadd(n, 1, prec=2*ctx.prec)
    k1 = ctx.fadd(k, 1, prec=2*ctx.prec)
    nk1 = ctx.fsub(n1, k, prec=2*ctx.prec)
    return ctx.gammaprod([n1], [k1, nk1])


@defun
def rf(ctx, x, n):
    """Return gamma(x+n)/gamma(x) (the rising factorial) through gammaprod.

    ``x+n`` is formed at twice the working precision.
    """
    xn = ctx.fadd(x, n, prec=2*ctx.prec)
    return ctx.gammaprod([xn], [x])


@defun
def ff(ctx, x, n):
    """Return gamma(x+1)/gamma(x-n+1) (the falling factorial) through gammaprod.

    The shifted arguments are formed at twice the working precision.
    """
    x1 = ctx.fadd(x, 1, prec=2*ctx.prec)
    xn1 = ctx.fadd(ctx.fsub(x, n, prec=2*ctx.prec), 1, prec=2*ctx.prec)
    return ctx.gammaprod([x1], [xn1])


@defun_wrapped
def fac2(ctx, x):
    """Return 2**(x/2) * (pi/2)**((cospi(x)-1)/4) * gamma(x/2+1).

    ``+inf`` is returned unchanged; any other infinite argument gives
    ``nan``.
    """
    if ctx.isinf(x):
        if x == ctx.inf:
            return x
        return ctx.nan
    return 2**(x/2)*(ctx.pi/2)**((ctx.cospi(x)-1)/4)*ctx.gamma(x/2+1)


@defun_wrapped
def barnesg(ctx, z):
    """Evaluate the Barnes G-function at z.

    Special cases: ``+inf`` is returned unchanged, other infinities give
    ``nan``, ``nan`` is returned unchanged, and real integers ``z <= 0``
    give ``z*0``.

    Otherwise, a reflection formula is used when the real part of ``z`` is
    below ``-ctx.dps``. In all remaining cases ``z`` is shifted upwards,
    dividing by ``gamma(z)`` at each step, until an asymptotic series in
    powers of ``1/z**2`` can be summed.

    NOTE: ``ctx.dps`` is raised inside this function for ``abs(z) > 5``
    and is not lowered again here.
    """
    if ctx.isinf(z):
        if z == ctx.inf:
            return z
        return ctx.nan
    if ctx.isnan(z):
        return z
    if (not ctx._im(z)) and ctx._re(z) <= 0 and ctx.isint(ctx._re(z)):
        return z*0
    # Account for size (would not be needed if computing log(G))
    if abs(z) > 5:
        ctx.dps += 2*ctx.log(abs(z),2)
    # Reflection formula
    if ctx.re(z) < -ctx.dps:
        w = 1-z
        pi2 = 2*ctx.pi
        u = ctx.expjpi(2*w)
        v = (ctx.j*ctx.pi/12 - ctx.j*ctx.pi*w**2/2 + w*ctx.ln(1-u) -
             ctx.j*ctx.polylog(2, u)/pi2)
        v = ctx.barnesg(2-z)*ctx.exp(v)/pi2**w
        # For an input of real type, discard the imaginary part of the
        # complex intermediate result.
        if ctx._is_real_type(z):
            v = ctx._re(v)
        return v
    # Estimate terms for asymptotic expansion
    # TODO: fixme, obviously
    N = ctx.dps // 2 + 5
    G = 1
    # Shift z upwards; each step divides the accumulated factor G by
    # gamma(z) before z is incremented.
    while abs(z) < N or ctx.re(z) < 1:
        G /= ctx.gamma(z)
        z += 1
    z -= 1
    # Leading terms of the exponent, followed by the Bernoulli-number
    # series below.
    s = ctx.mpf(1)/12
    s -= ctx.log(ctx.glaisher)
    s += z*ctx.log(2*ctx.pi)/2
    s += (z**2/2-ctx.mpf(1)/12)*ctx.log(z)
    s -= 3*z**2/4
    z2k = z2 = z**2
    for k in xrange(1, N+1):
        t = ctx.bernoulli(2*k+2) / (4*k*(k+1)*z2k)
        # Stop at the first term smaller than eps; that term is not added.
        if abs(t) < ctx.eps:
            break
        z2k *= z2
        s += t
    # NOTE: if all N terms are consumed without reaching eps, the partial
    # sum is used as-is and no warning is issued.
    return G*ctx.exp(s)


@defun
def superfac(ctx, z):
    """Return barnesg(z+2)."""
    return ctx.barnesg(z+2)


@defun_wrapped
def hyperfac(ctx, z):
    """Evaluate the hyperfactorial at z.

    ``+inf`` is returned unchanged. For a negative real integer ``z = n``
    the value is ``hyperfac(-n-1)``, negated when ``(n+1)//2`` is odd, and
    promoted to complex if ``z`` is of complex type. Otherwise the result
    is ``exp(z*loggamma(z+1)) / barnesg(z+1)``.

    For ``abs(z) > 5`` the working precision is raised by
    ``4*int(log2(abs(z)))`` bits while the numerator (or the reflected
    value) is computed, and is lowered again before returning.
    """
    # XXX: estimate needed extra bits accurately
    if z == ctx.inf:
        return z
    if abs(z) > 5:
        extra = 4*int(ctx.log(abs(z),2))
    else:
        extra = 0
    ctx.prec += extra
    try:
        if not ctx._im(z) and ctx._re(z) < 0 and ctx.isint(ctx._re(z)):
            n = int(ctx.re(z))
            h = ctx.hyperfac(-n-1)
            if ((n+1)//2) & 1:
                h = -h
            if ctx._is_complex_type(z):
                return h + 0j
            return h
        zp1 = z+1
        # ctx.gamma(zp1)**z would give the wrong branch cut, so
        # exponentiate z*loggamma(zp1) instead.
        v = ctx.exp(z*ctx.loggamma(zp1))
    finally:
        # Undo the precision bump on every exit path (early return or
        # exception) instead of leaving it for code outside this function.
        ctx.prec -= extra
    # The division and the barnesg call run after the precision has been
    # lowered again.
    return v / ctx.barnesg(zp1)


@defun_wrapped
def loggamma_old(ctx, z):
    """Return a logarithm of ``ctx.gamma_old(z)`` with an adjusted imaginary part.

    * Positive real ``z``: ``ln(gamma_old(z))`` is returned directly.
    * ``z`` with nonzero imaginary part: a multiple of ``2*pi*j`` is added,
      chosen so that the imaginary part is close to the estimate ``gi``
      computed in the body.
    * Real ``z`` with ``a < 0``: ``(n - n%2)*pi*j`` is added, where
      ``n = floor(a)``.
    * Otherwise ``ln(gamma_old(z))`` is returned unchanged.
    """
    a = ctx._re(z)
    b = ctx._im(z)
    if not b and a > 0:
        return ctx.ln(ctx.gamma_old(z))
    u = ctx.arg(z)
    w = ctx.ln(ctx.gamma_old(z))
    if b:
        # gi estimates the imaginary part; round the discrepancy with
        # Im(w) to the nearest multiple of 2*pi.
        gi = -b - u/2 + a*u + b*ctx.ln(abs(z))
        n = ctx.floor((gi-ctx._im(w))/(2*ctx.pi)+0.5) * (2*ctx.pi)
        return w + n*ctx.j
    elif a < 0:
        n = int(ctx.floor(a))
        w += (n-(n%2))*ctx.pi*ctx.j
    return w


'''
@defun
def psi0(ctx, z):
    """Shortcut for psi(0,z) (the digamma function)"""
    return ctx.psi(0, z)

@defun
def psi1(ctx, z):
    """Shortcut for psi(1,z) (the trigamma function)"""
    return ctx.psi(1, z)

@defun
def psi2(ctx, z):
    """Shortcut for psi(2,z) (the tetragamma function)"""
    return ctx.psi(2, z)

@defun
def psi3(ctx, z):
    """Shortcut for psi(3,z) (the pentagamma function)"""
    return ctx.psi(3, z)
'''