/*****************************************/
/* Stack Shield v0.5                     */
/* by Vendicator 1999                    */
/*****************************************/
/* File: stackshield.c                   */
/* Stack Shield file protection program  */
/*****************************************/

#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "globalretstack.h"

/* Forward declarations of the functions defined in this file. */
void usage(char *progname);
void parsefile(FILE *srcptr, FILE *destptr);
void doprolog(FILE *fileptr);
void doepilog(FILE *fileptr);
void domainprolog(FILE *fileptr);
void doheader(FILE *fileptr);
void dodetectattackepilog(FILE *fileptr);
void doanticrashprolog(FILE *fileptr);
void doanticrashepilog(FILE *fileptr);

/* Command-line settings, filled in by main(). */
int buffelements = 256;   /* -l: buffer size in 4-byte elements */
int anticrash = 0;        /* -c: nonzero makes parsefile() use doanticrashprolog() and add doanticrashepilog() */
int detectattack = 0;     /* -d: nonzero makes parsefile() use dodetectattackepilog() instead of doepilog() */
char entrypoint[257];     /* -e: entry point label searched for in each line; main() sets "main" and appends ":" */

/* Counters substituted into the templates so every emitted block gets unique label numbers. */
int prologcount = 0;      /* incremented by doprolog()/doanticrashprolog(); replaces <PROLOGCOUNT> */
int epilogcount = 0;      /* incremented by doepilog()/dodetectattackepilog(); replaces <EPILOGCOUNT> */
int buffsize;             /* buffelements * 4, computed in main(); replaces <BUFFSIZE> */

/*
 * usage: print the command-line help to stderr and terminate the program
 * with EXIT_FAILURE. Does not return.
 *
 * progname: name shown in the "Usage:" line (normally argv[0]).
 */
void usage(char *progname) {
  fprintf(stderr, "Usage: %s [-cd][-l elements][-e entrypoint] srcfile destfile\n", progname);
  fprintf(stderr, "Where:\n");
  fprintf(stderr, "srcfile is the input assembly file\n");
  fprintf(stderr, "destfile is the output assembly file\n");
  fprintf(stderr, "Options:\n");
  fprintf(stderr, "-c prevents runtime errors if too many nested calls are executed\n");
  fprintf(stderr, "-d terminate immediately when a buffer overflow is detected\n");
  fprintf(stderr, "-l elements set buffer size to 'elements' number (elements are 4 bytes) (default 256)\n");
  fprintf(stderr, "-e entrypoint specify the program entry point symbol (default main)\n");

  exit(EXIT_FAILURE);
}

/*
 * doprolog: increment prologcount and write the grsprolog template
 * (presumably declared in globalretstack.h) to fileptr, replacing every
 * "<PROLOGCOUNT>" token with the new count.
 *
 * The template is streamed to the file piece by piece, so no heap copies
 * are needed. The former version calloc'd two buffers on every call, never
 * checked them for NULL and never freed them.
 */
void doprolog(FILE *fileptr) {
  static const char token[] = "<PROLOGCOUNT>";
  const char *rest = grsprolog;
  const char *hit;

  prologcount++;

  while ((hit = strstr(rest, token)) != NULL) {
    fwrite(rest, 1, (size_t)(hit - rest), fileptr);  /* text before the token */
    fprintf(fileptr, "%d", prologcount);
    rest = hit + (sizeof token - 1);
  }
  fputs(rest, fileptr);  /* text after the last token */
}

/*
 * doepilog: increment epilogcount and write the grsepilog template
 * (presumably declared in globalretstack.h) to fileptr, replacing every
 * "<EPILOGCOUNT>" token with the new count.
 *
 * The template is streamed to the file piece by piece, so no heap copies
 * are needed. The former version calloc'd two buffers on every call, never
 * checked them for NULL and never freed them.
 */
void doepilog(FILE *fileptr) {
  static const char token[] = "<EPILOGCOUNT>";
  const char *rest = grsepilog;
  const char *hit;

  epilogcount++;

  while ((hit = strstr(rest, token)) != NULL) {
    fwrite(rest, 1, (size_t)(hit - rest), fileptr);  /* text before the token */
    fprintf(fileptr, "%d", epilogcount);
    rest = hit + (sizeof token - 1);
  }
  fputs(rest, fileptr);  /* text after the last token */
}

/*
 * domainprolog: write the grsmainprolog template (presumably declared in
 * globalretstack.h) to fileptr, replacing every "<BUFFSIZE>" token with
 * buffsize. parsefile() emits this right after the entry point label line.
 *
 * The template is streamed to the file piece by piece. The former version
 * copied into a heap buffer sized from grsmainprolog but passed
 * strlen(grsheader)+256 to strncpy, which zero-pads to that length. That
 * overran the buffer whenever grsheader was the longer template. It also
 * leaked both buffers and never checked calloc for NULL.
 */
void domainprolog(FILE *fileptr) {
  static const char token[] = "<BUFFSIZE>";
  const char *rest = grsmainprolog;
  const char *hit;

  while ((hit = strstr(rest, token)) != NULL) {
    fwrite(rest, 1, (size_t)(hit - rest), fileptr);  /* text before the token */
    fprintf(fileptr, "%d", buffsize);
    rest = hit + (sizeof token - 1);
  }
  fputs(rest, fileptr);  /* text after the last token */
}

/*
 * doheader: write the grsheader template (presumably declared in
 * globalretstack.h) to fileptr, replacing every "<BUFFSIZE>" token with
 * buffsize. main() calls this once, before the translated source.
 *
 * The template is streamed to the file piece by piece, so no heap copies
 * are needed. The former version leaked two calloc'd buffers and never
 * checked calloc for NULL.
 */
void doheader(FILE *fileptr) {
  static const char token[] = "<BUFFSIZE>";
  const char *rest = grsheader;
  const char *hit;

  while ((hit = strstr(rest, token)) != NULL) {
    fwrite(rest, 1, (size_t)(hit - rest), fileptr);  /* text before the token */
    fprintf(fileptr, "%d", buffsize);
    rest = hit + (sizeof token - 1);
  }
  fputs(rest, fileptr);  /* text after the last token */
}

/*
 * dodetectattackepilog: increment epilogcount and write the
 * grsdetectattackepilog template (presumably declared in globalretstack.h)
 * to fileptr, replacing every "<EPILOGCOUNT>" token with the new count.
 * parsefile() uses this instead of doepilog() when -d is given.
 *
 * The template is streamed to the file piece by piece, so no heap copies
 * are needed. The former version leaked two calloc'd buffers per call and
 * never checked calloc for NULL.
 */
void dodetectattackepilog(FILE *fileptr) {
  static const char token[] = "<EPILOGCOUNT>";
  const char *rest = grsdetectattackepilog;
  const char *hit;

  epilogcount++;

  while ((hit = strstr(rest, token)) != NULL) {
    fwrite(rest, 1, (size_t)(hit - rest), fileptr);  /* text before the token */
    fprintf(fileptr, "%d", epilogcount);
    rest = hit + (sizeof token - 1);
  }
  fputs(rest, fileptr);  /* text after the last token */
}

/*
 * doanticrashprolog: increment prologcount and write the grsanticrashprolog
 * template (presumably declared in globalretstack.h) to fileptr, replacing
 * every "<PROLOGCOUNT>" token with the new count. parsefile() uses this
 * instead of doprolog() when -c is given.
 *
 * The template is streamed to the file piece by piece, so no heap copies
 * are needed. The former version leaked two calloc'd buffers per call and
 * never checked calloc for NULL.
 */
void doanticrashprolog(FILE *fileptr) {
  static const char token[] = "<PROLOGCOUNT>";
  const char *rest = grsanticrashprolog;
  const char *hit;

  prologcount++;

  while ((hit = strstr(rest, token)) != NULL) {
    fwrite(rest, 1, (size_t)(hit - rest), fileptr);  /* text before the token */
    fprintf(fileptr, "%d", prologcount);
    rest = hit + (sizeof token - 1);
  }
  fputs(rest, fileptr);  /* text after the last token */
}

/*
 * doanticrashepilog: write the grsanticrashepilog template (presumably
 * declared in globalretstack.h) to fileptr, replacing every "<EPILOGCOUNT>"
 * token with epilogcount + 1.
 *
 * It does not increment epilogcount itself. parsefile() emits it just
 * before doepilog()/dodetectattackepilog(), and that call performs the
 * increment, so both blocks end up using the same number.
 *
 * The template is streamed to the file piece by piece, so no heap copies
 * are needed. The former version leaked two calloc'd buffers per call and
 * never checked calloc for NULL.
 */
void doanticrashepilog(FILE *fileptr) {
  static const char token[] = "<EPILOGCOUNT>";
  const char *rest = grsanticrashepilog;
  const char *hit;
  int number = epilogcount + 1;

  while ((hit = strstr(rest, token)) != NULL) {
    fwrite(rest, 1, (size_t)(hit - rest), fileptr);  /* text before the token */
    fprintf(fileptr, "%d", number);
    rest = hit + (sizeof token - 1);
  }
  fputs(rest, fileptr);  /* text after the last token */
}

/*
 * parsefile: copy the assembly text in srcptr to destptr, instrumenting it
 * on the way.
 *
 *  - After a "pushl %ebp" line followed by a "movl %esp,%ebp" line, a prolog
 *    is emitted: doanticrashprolog() with -c, doprolog() otherwise.
 *  - A "movl %ebp,%esp" / "popl %ebp" / "ret" sequence is replaced by the
 *    epilog code, followed by a canonical copy of those three instructions.
 *    The epilog is doanticrashepilog() first with -c, then
 *    dodetectattackepilog() with -d or doepilog() otherwise.
 *  - After a line containing entrypoint, domainprolog() output is inserted.
 *  - Every other line is copied unchanged.
 *
 * Lookahead lines that turn out not to complete a pattern are written out
 * and examined again, instead of being silently dropped. Hitting EOF in the
 * middle of a pattern flushes the lines already read.
 *
 * Lines longer than 256 characters are read by fgets in several pieces.
 */
void parsefile(FILE *srcptr, FILE *destptr) {
  char line[257];
  char movline[257];   /* held "movl %ebp,%esp" line of a candidate epilog */
  char popline[257];   /* held "popl %ebp" line of a candidate epilog */
  int pending = 0;     /* nonzero: line holds input that still needs processing */

  for (;;) {
    if (!pending && !fgets(line, sizeof line, srcptr))
      break;
    pending = 0;

    if (strstr(line, "pushl %ebp")) {
      fprintf(destptr, "%s", line);
      if (!fgets(line, sizeof line, srcptr))
        break;
      if (strstr(line, "movl %esp,%ebp")) {
        fprintf(destptr, "%s", line);
        if (anticrash)
          doanticrashprolog(destptr);
        else
          doprolog(destptr);
      } else {
        pending = 1;   /* not a frame setup: handle this line normally */
      }
    }
    else if (strstr(line, "movl %ebp,%esp")) {
      snprintf(movline, sizeof movline, "%s", line);

      if (!fgets(line, sizeof line, srcptr)) {
        fprintf(destptr, "%s", movline);
        break;
      }
      if (!strstr(line, "popl %ebp")) {
        fprintf(destptr, "%s", movline);
        pending = 1;
        continue;
      }
      snprintf(popline, sizeof popline, "%s", line);

      if (!fgets(line, sizeof line, srcptr)) {
        fprintf(destptr, "%s%s", movline, popline);
        break;
      }
      /* NOTE(review): matches "ret" anywhere in the line, as before. */
      if (!strstr(line, "ret")) {
        fprintf(destptr, "%s%s", movline, popline);
        pending = 1;
        continue;
      }

      if (anticrash)
        doanticrashepilog(destptr);
      if (detectattack)
        dodetectattackepilog(destptr);
      else
        doepilog(destptr);
      fprintf(destptr, "\t%s\n", "movl %ebp,%esp");
      fprintf(destptr, "\t%s\n", "popl %ebp");
      fprintf(destptr, "\t%s\n", "ret");
    }
    /* NOTE(review): substring match, so with the default "main:" a label such
       as "domain:" would also match. */
    else if (strstr(line, entrypoint)) {
      fprintf(destptr, "%s", line);
      domainprolog(destptr);
    }
    else {
      fprintf(destptr, "%s", line);
    }
  }
}

int main(int argc, char **argv) {
  char srcfile[257];
  char destfile[257];
  FILE *srcptr, *destptr;
  int opt;

  strncpy(entrypoint, "main", 256);
  entrypoint[256]='\0';
  while ((opt=getopt(argc, argv, "cdl:e:")) != EOF) {
    switch (opt) {
      case 'c':
        anticrash=-1;
	break;
      case 'd':
        detectattack=-1;
	break;
      case 'l':
        buffelements=atoi(optarg);
	break;
      case 'e':
        strncpy(entrypoint, optarg, 256);
    }
  }
  if (!(buffelements > 0)) {
    fprintf(stderr, "Error: buffer element number must be a greater than 0\n");
    exit(EXIT_FAILURE);
  }
  buffsize=buffelements*4;
  strncat(entrypoint, ":", 256-strlen(entrypoint));
  argc-=(optind-1);
  argv+=(optind-1);

  if (argc < 3)
    usage(argv[0]);
  strncpy(srcfile, argv[1], 257);
  strncpy(destfile, argv[2], 257);
  srcfile[256]=destfile[256]='\0';

  if ((srcptr=fopen(srcfile, "r")) == NULL) {
    fprintf(stderr, "Error: cannot read %s\n", srcfile);
    exit(EXIT_FAILURE);
  }

  if ((destptr=fopen(destfile, "w")) == NULL) {
    fprintf(stderr, "Error: cannot create %s\n", destfile);
    exit(EXIT_FAILURE);
  }

  doheader(destptr);
  parsefile(srcptr, destptr);
  
  fclose(srcptr);
  fclose(destptr);

  exit(EXIT_SUCCESS);
}