/* Copyright Julian D. A. Wiseman 2005. This version 20:00 Thursday 17 November 2005. */
/* This program may be distributed under the terms of the GNU General Public License, */
/* currently available at www.gnu.org/licenses/gpl.txt. It is requested, but not      */
/* required, that authors of modifications or improvements inform the original author */
/* of these, contactable via www.jdawiseman.com/papers/easymath/coin-stopping.html.   */

/* Flip a coin repeatedly, stopping any time after the first toss. */
/* Score the proportion of throws that are heads: h/(h+t)          */
/* With an optimal strategy, what is the expected score?           */
/* Answer: 0.792953506407...                                       */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <limits.h>
#include <float.h>

/* Floating-point type used for every expected value, with its machine epsilon */
/* and the printf conversion used whenever a value of that type is printed.    */
typedef long double realtype;
#define epsilonrealtype (LDBL_EPSILON)
#define longtypeprintingcode "%0.55Lf"

/* Defines the range of t over which the strategy boundary is written to the   */
/* strategy file. A row is written only if t is a multiple of subsetmodulus,   */
/* lies within [subsetminabsolute, subsetmaxabsolute], and also lies within    */
/* [subsetminproportionate*max_t, subsetmaxproportionate*max_t].               */
#define subsetminabsolute 0
#define subsetmaxabsolute 2048
#define subsetminproportionate 0
#define subsetmaxproportionate 0.25
#define subsetmodulus 1

/* Set to 1 to also compute the boundary statistics (pr, hb, hh); 0 to skip.   */
/* This must be an integer literal, because it is tested with #if: without     */
/* <stdbool.h>, and before C23, the words true and false are ordinary          */
/* identifiers in an #if and BOTH evaluate to 0.                               */
#define calculatestatistics 0

/* There is an occasional print to the screen, to confirm the process is alive. */
/* It happens once at least this many rows have been computed since the last.   */
#define logfrequency 1073741824 /* 2^30 */

/* Discards what is left of the current line on standard input, so that text */
/* which scanf could not convert is not offered to scanf again.              */
static void discard_input_line(void)
{
	int c;
	do
		c = getchar();
	while( c != '\n' && c != EOF );
}

/* Returns non-zero if n*multiplier, truncated to an integer, exceeds n; that */
/* is the condition for successive values of max_t to grow. The range checks  */
/* come first because converting a NaN or an out-of-range float to int is     */
/* undefined behaviour.                                                       */
static int multiplier_grows_n(int n, float multiplier)
{
	float product = n * multiplier;
	if( !(product > 0) )
		return 0;  /* zero, negative or NaN: never larger than n, which is >= 2 */
	if( product >= (float)INT_MAX )
		return 1;  /* too large to convert to int, and so certainly larger than n */
	return ((int)product) > n;
}

/* Announces and opens filename for writing. On failure reports the problem */
/* to both stdout and stderr and terminates the program with EXIT_FAILURE.  */
static FILE *open_output(const char *filename)
{
	FILE *fp;
	printf("Opening '%s'.\n",filename);
	if(NULL==(fp=fopen(filename, "w")))
		{printf("fopen problem with '%s'.",filename); fprintf(stderr,"fopen problem with '%s'.",filename); exit(EXIT_FAILURE);}
	return fp;
}

/* Allocates count elements of element_size bytes each. Returns NULL, exactly */
/* as a failed malloc would, if count is not positive or if the byte count    */
/* would not fit in a size_t (which would otherwise wrap to a small request). */
static void *alloc_array(signed long count, size_t element_size)
{
	if( count <= 0 || element_size == 0 )
		return NULL;
	if( (unsigned long)count > ((size_t)-1) / element_size )
		return NULL;
	return malloc( ((size_t)count) * element_size );
}

/* Several times might or might not need to do this calculation, which is therefore put in a #define. */
/* It uses main's locals d, h, pr, hb and hh: it merges the boundary statistics of cell h+1 (this     */
/* column) and cell h (next column), weighting by their probabilities, and halves the summed          */
/* probability. When statistics are switched off it expands to nothing.                               */
#if calculatestatistics
	#define calc_statistics { d = pr[h] + pr[h+1]; \
		if( d > 0 )  \
		{ \
			hb[h] = ( pr[h]*hb[h] + pr[h+1]*hb[h+1] ) / d; \
			hh[h] = ( pr[h]*hh[h] + pr[h+1]*hh[h+1] ) / d; \
		} \
		pr[h] = d / 2;}
#else
	#define calc_statistics  
#endif

/* Asks for a starting horizon n and a growth multiplier, then for each horizon       */
/* max_t = n, n*multiplier, ... computes by backward induction the expected score at   */
/* the start, ev[0], given that the state is valued at 0.5 once t reaches max_t.       */
/* Here h and t index the state, the score of stopping being h/(h+t), and ev[h] holds  */
/* one column (fixed t) of expected values, updated in place from t=max_t down to t=0. */
/* Results go to three tab-separated files: scores (one row per horizon), strategy     */
/* (the largest h at which tossing again is optimal, for selected t), and precision    */
/* (progress notes). The program ends when memory cannot be allocated, when max_t      */
/* would overflow a long, or when i reaches INT_MAX.                                   */
int main(void)
{
	FILE *fp1, *fp2, *fp3;
	int i, start_n, scanned;
	register signed long t,h;
	signed long max_t, max_h, count=0;
	realtype *ev, d;

	float multiplier_n;
	realtype prevev00=0.75, prevprevev00=0.5;  /* Any values would do, though 0.75 and 0.5 are natural choices */
	char filename1[127], filename2[127], filename3[127] ;

#if calculatestatistics
	realtype *pr;  /* probability of hitting max_t boundary */
	realtype prevpr00 = 1;  /* Any value would do, though 1 is a natural choice */
	realtype *hb;  /* expected value of h conditional on hitting max_t boundary */
	realtype *hh;  /* expected value of h^2 conditional on hitting max_t boundary */
#endif


	/* sizeof yields a size_t, which need not be unsigned long, hence the casts to match %lu */
	printf("sizeof(float)=%lu   sizeof(double)=%lu   sizeof(long double)=%lu   used size=%lu\n", (unsigned long)sizeof(float), (unsigned long)sizeof(double), (unsigned long)sizeof(long double), (unsigned long)sizeof(realtype) );

	/* Ask until a usable n is given. Unparseable text is thrown away rather than */
	/* re-read for ever, and end of input is fatal rather than an endless loop.   */
	for( ;; )
	{
		printf("\nStarting value of n, >=2? ");
		scanned = scanf ("%d",&start_n);
		if( scanned == 1 && start_n >= 2 )
			break;
		if( scanned == EOF )
			{fprintf(stderr,"No usable starting value of n on standard input.\n"); exit(EXIT_FAILURE);}
		if( scanned != 1 )
			discard_input_line();
	}
	for( ;; )
	{
		printf("\nEach loop n multiplies by a number >= %f. What number? ", (float)(1.0+(1.0/((float)start_n))) );
		scanned = scanf ("%f",&multiplier_n);
		if( scanned == 1 && multiplier_grows_n(start_n, multiplier_n) )
			break;
		if( scanned == EOF )
			{fprintf(stderr,"No usable multiplier on standard input.\n"); exit(EXIT_FAILURE);}
		if( scanned != 1 )
			discard_input_line();
	}
	printf("start_n = %i        multiplier_n = %f\n", start_n, multiplier_n);

	/* The longest name is well inside 127 bytes: an int prints in at most 11 characters */
	/* and a positive float with %f in at most 46 (39 integer digits, '.', 6 decimals).  */
	sprintf(filename1, "coin_%i_%f_scores.txt",    start_n, multiplier_n);
	sprintf(filename2, "coin_%i_%f_strategy.txt",  start_n, multiplier_n);
	sprintf(filename3, "coin_%i_%f_precision.txt", start_n, multiplier_n);

	fp1 = open_output(filename1);
	fp2 = open_output(filename2);
	fp3 = open_output(filename3);

	fprintf(fp1,  /* scores */
				"i"
		"\t"	"max_t"
		"\t"	"EV"
		"\t"	"Diff"
		"\t"	"RatioDiffs"
		"\t"	"EstimatedLimit"
#if calculatestatistics
		"\t"	"ProbReachRightBoundary"
		"\t"	"Ratio"
		"\t"	"ExpectedhBoundary"
		"\t"	"ExpectedhhBoundary"
#endif
		"\n"
	);
	fprintf(fp2,  /* strategy */
		    	"i"
		"\t"	"max_t"
		"\t"	"t"
		"\t"	"h"
		"\t"	"ev[h,t]"
		"\n"
	);
	fprintf(fp3,  /* precision notes */
				"i"
		"\t"	"max_t"
		"\t"	"t"
		"\t"	"h"
		"\n"
	);

	max_t = start_n ;
	i = 1;  /* i used only to label output rows */
	while( i < INT_MAX )
	{
		max_h = max_t - 1;

		/* NOTE(review): an allocation failure exits with EXIT_SUCCESS. Presumably this is */
		/* deliberate, running out of memory being the usual way a run ends - confirm.     */
		if( NULL == (ev = alloc_array( 1+max_h, sizeof(realtype) ) ) )
		{
			printf("malloc problem with ev"); fprintf(stderr,"malloc problem with ev");
			fclose(fp1); fclose(fp2); fclose(fp3);  exit(EXIT_SUCCESS);
		}
#if calculatestatistics
		if( NULL == (pr = alloc_array( 1+max_h, sizeof(realtype) ) ) )
		{
			printf("malloc problem with pr"); fprintf(stderr,"malloc problem with pr");
			fclose(fp1); fclose(fp2); fclose(fp3);  free(ev);  exit(EXIT_SUCCESS);
		}
		if( NULL == (hb = alloc_array( 1+max_h, sizeof(realtype) ) ) )
		{
			printf("malloc problem with hb"); fprintf(stderr,"malloc problem with hb");
			fclose(fp1); fclose(fp2); fclose(fp3);  free(ev); free(pr);  exit(EXIT_SUCCESS);
		}
		if( NULL == (hh = alloc_array( 1+max_h, sizeof(realtype) ) ) )
		{
			printf("malloc problem with hh"); fprintf(stderr,"malloc problem with hh");
			fclose(fp1); fclose(fp2); fclose(fp3);  free(ev); free(pr); free(hb);  exit(EXIT_SUCCESS);
		}
#endif

		/* initialise rightmost column, that being t==max_t */
		t = max_t;
		/* As max_h is max_t-1 this first loop has nothing to do; every h falls to the loop after it. */
		for( h=max_h ; h>t ; h-- )
		{
			ev[h] = ((realtype)h) / (h+t); 
#if calculatestatistics
			pr[h] = ((realtype)0);
			hb[h] = ((realtype)h);
			hh[h] = ((realtype)h)*((realtype)h);
#endif
		}
		for( ; h>=0 ; h-- )
		{
			ev[h] = ((realtype)0.5);  /* value assigned to every state on the t==max_t boundary */
#if calculatestatistics
			pr[h] = ((realtype)1);
			hb[h] = ((realtype)h);
			hh[h] = ((realtype)h)*((realtype)h);
#endif
		}

		/* Now compute all columns except the rightmost, being t<max_t. */
		/* On entry to each pass ev[h] holds the value for column t+1;  */
		/* working from high h to low h overwrites it with column t.    */
		for( t-- ; t>=0 ; t-- )
		{
			/* last row: the cell above it, h+1, is taken to be a stopping state worth (h+1)/(h+1+t) */
			h=max_h;  /* which is not less than t */
			ev[h] = ( ev[h] + ( (realtype)(h+1) / ( h+1 + t )) ) / 2;
#if calculatestatistics
			pr[h] = pr[h] / 2;
			/* hb[h] = hb[h]; */
#endif
			d = ((realtype)h) / (h+t); 
			if( ev[h] < d )
			{
				ev[h] = d;
#if calculatestatistics
				pr[h]=0;  /* hb[h] is irrelevant */
#endif
			}

			/* rows between the last and first */
			count += h;  /* We add the full amount to go, and then subtract those undone, calculating count without having it in the innermost loop */
			for( h-- ; h>0 ; h-- )  /* testing whether to stop */
			{
				ev[h] = ( ev[h+1] + ev[h] ) / 2;  /* value of tossing again: mean of the h+1 cell and the t+1 cell */
				d = ((realtype)h) / (h+t);        /* value of stopping now */
				if( ev[h] < d )
				{
					ev[h] = d;  /* Optimal strategy requires that a player stops */
#if calculatestatistics
					pr[h]=0;  /* hb[h] is irrelevant */
#endif
				}
				else
				{   /* So no need to test for any h from here down to 1 */
					if( (t % subsetmodulus == 0)
					&& (t<=subsetmaxabsolute) && (t<=subsetmaxproportionate*max_t) 
					&& (t>=subsetminabsolute) && (t>=subsetminproportionate*max_t) )
						fprintf(fp2,
							"%i" "\t" "%li" "\t" "%li" "\t" "%li" "\t" longtypeprintingcode "\n",  /* max_t is a signed long, so %li */
							i,
							max_t,
							t,
							h,
							ev[h]
						);
					if( max_h > h+1 ) /* reset because don't need to test the larger h's for smaller t, */
						max_h = h+1;  /* ... as they will all be stopping states. */

					calc_statistics ;

					break;  /* break this h loop, as no further need to test stopping condition */
				}
			}  /* end for( h ... ) */

			/* Know that the optimal strategy for all h from here to 0 is to toss again. */
			/* Infrequently, test whether machine precision limit has been reached, */
			/* and if it has, leave ev(t,h) equal to ev(t+1,h) for all smaller h */
			for( ; h>16 ; )  /* testing every 16 for speed */
			{
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics

				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics

				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics

				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics
				h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics

				if( ev[h] <= ((realtype)0.5) + epsilonrealtype ) /* Machine precision limit, so no */
					goto machine_precision_reached;              /* need to update for smaller h */
			}  /* end second for( h ... ) */

			for( ; h>1 ; )  /* last few non-stopping cases, except top row */
				{h--;   ev[h] = ( ev[h+1] + ev[h] ) / 2;   calc_statistics}
			
			/* top row done here to avoid possible divide-by-zero */ 
			h=0;   ev[0] = ( ev[1] + ev[0] ) / 2;   calc_statistics
			machine_precision_reached: ;

			count -= h;  /* Number of rows done. this being done outside h loop for speed. */
			/* Possibly output something to screen, to confirm that the program is still alive. */
			if( count >= logfrequency )
			{
				count=0;
				printf("i=%i     t=%li     reached machine precision with h >= %li\n",i,t,h);
				fprintf(fp3, "%i" "\t" "%li" "\t" "%li" "\t" "%li" "\n", i, max_t, t, h);
			}
		} /* for( t ... ) */

		printf( longtypeprintingcode "\t" "%i" "\t" "%li" "\n",
			ev[0],
			i,
			max_t
		);
		/* Columns: EV, its change since the previous horizon, the ratio of the last two changes, */
		/* and a limit extrapolated from the last three EVs (this expression has the form of      */
		/* Aitken's delta-squared extrapolation).                                                 */
		fprintf( fp1,
			"%i\t%li\t" longtypeprintingcode "\t" longtypeprintingcode "\t" longtypeprintingcode "\t" longtypeprintingcode,
			i,
			max_t,
			ev[0],
			ev[0]-prevev00,
			(ev[0]-prevev00)/(prevev00-prevprevev00),
			ev[0] + (ev[0]-prevev00) * (prevev00-ev[0])/(ev[0] + prevprevev00 - 2*prevev00)
		);
#if calculatestatistics
		fprintf( fp1, "\t" longtypeprintingcode "\t" longtypeprintingcode "\t" longtypeprintingcode "\t" longtypeprintingcode, 
			pr[0],
			pr[0] / prevpr00,
			hb[0],
			hh[0]
		);
		prevpr00 = pr[0];
#endif
		fprintf( fp1, "\n" );
		fflush(fp1);
		fflush(fp2);

		prevprevev00 = prevev00;
		prevev00 = ev[0];

		free(ev);
#if calculatestatistics
		free(pr);
		free(hb);
		free(hh);
#endif
		
		if( max_t >= 1 + (long)(LONG_MAX / multiplier_n) )
			break; /* max_t would be too big in next pass through loop  */
		max_t *= multiplier_n ;  /* product is formed in floating point, then truncated back to a long */
		i++;
	} /* while( i < INT_MAX ) */

	fclose(fp1);
	fclose(fp2);
	fclose(fp3);
	exit(EXIT_SUCCESS);
}
#undef calc_statistics
