C11 lexer (and parser?) experiments
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1157 lines
23 KiB

#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include <immintrin.h>
/*
 * Larger of a and b. This is a function-like macro: whichever argument is
 * selected gets evaluated twice, so arguments must not have side effects.
 */
#define MAX(a, b) ((a) > (b) ? (a) : (b))
/*
 * Lexical category of a token. The enumerators are numbered from 0 in
 * declaration order; TOKEN_COUNT is the number of categories.
 */
enum token_kind {
    TOKEN_KEYWORD,
    TOKEN_IDENTIFIER,
    TOKEN_INTEGER_CONSTANT,
    TOKEN_FLOATING_CONSTANT,
    /* Original, misspelled name; kept as an alias with the same value so
     * that code using it keeps compiling. */
    TOKEN_FLOATING_CONTANT = TOKEN_FLOATING_CONSTANT,
    TOKEN_ENUMERATION_CONSTANT,
    TOKEN_CHARACTER_CONSTANT,
    TOKEN_STRING_LITERAL,
    TOKEN_PUNCTUATOR,
    TOKEN_HEADER_NAME,
    TOKEN_PP_NUMBER,
    TOKEN_COUNT,
};
/*
 * A cursor over a run of characters. advance() moves `text` forward and
 * shrinks `size` by the same amount, so `size` tracks how many characters
 * remain from `text` onwards.
 */
struct str {
char *text;
int size;
};
/*
 * One lexed token: its category and the half-open character range
 * [start, end) it occupies.
 */
struct token {
enum token_kind kind;
char *start;
char *end; // one past end
};
/*
 * Reading of CLOCK_MONOTONIC_RAW converted to microseconds.
 * The return value of clock_gettime() is not checked.
 */
static unsigned long long
usec_now(void)
{
    struct timespec ts = { 0 };
    clock_gettime(CLOCK_MONOTONIC_RAW, &ts);

    unsigned long long seconds_part = ts.tv_sec * 1000000ULL;
    unsigned long long nanos_part = ts.tv_nsec / 1000ULL;
    return seconds_part + nanos_part;
}
/*
 * Move the cursor forward by `by` characters: the text pointer grows and
 * the remaining size shrinks by the same amount. No bounds check is made.
 */
static inline void
advance(struct str *s, int by)
{
    s->size -= by;
    s->text += by;
}
/*
 * 1 if the first character of s is an ASCII letter or '_', else 0.
 * Only s.text[0] is examined; s.size is not consulted.
 */
static int
nondigit(struct str s)
{
    char ch = *s.text;
    if (ch == '_') {
        return 1;
    }
    if (ch >= 'a' && ch <= 'z') {
        return 1;
    }
    return ch >= 'A' && ch <= 'Z';
}
/*
 * 1 if the first character of s is a decimal digit, else 0.
 * Only s.text[0] is examined; s.size is not consulted.
 */
static int
digit(struct str s)
{
    char ch = *s.text;
    return ch >= '0' && ch <= '9';
}
/*
 * 1 if the first character of s is one of '1'..'9', else 0.
 * Only s.text[0] is examined; s.size is not consulted.
 */
static int
nonzero_digit(struct str s)
{
    char ch = *s.text;
    return ch >= '1' && ch <= '9';
}
/*
 * 1 if the first character of s is one of '0'..'7', else 0.
 * Only s.text[0] is examined; s.size is not consulted.
 */
static int
octal_digit(struct str s)
{
    char ch = *s.text;
    return ch >= '0' && ch <= '7';
}
/*
 * Length (2) of a leading "0x" / "0X", or 0 if s does not start with one
 * or holds fewer than two characters.
 */
static int
hexadecimal_prefix(struct str s)
{
    if (s.size < 2) {
        return 0;
    }
    if (s.text[0] != '0') {
        return 0;
    }
    return (s.text[1] == 'x' || s.text[1] == 'X') ? 2 : 0;
}
/*
 * 1 if the first character of s is a hexadecimal digit (0-9, a-f, A-F),
 * else 0. Only s.text[0] is examined; s.size is not consulted.
 */
static int
hexadecimal_digit(struct str s)
{
    char ch = *s.text;
    if (digit(s)) {
        return 1;
    }
    if (ch >= 'A' && ch <= 'F') {
        return 1;
    }
    return ch >= 'a' && ch <= 'f';
}
/*
 * 1 if the first character of s is 'u' or 'U', else 0.
 * Only s.text[0] is examined; s.size is not consulted.
 */
static int
unsigned_suffix(struct str s)
{
    switch (*s.text) {
    case 'u':
    case 'U':
        return 1;
    default:
        return 0;
    }
}
/*
 * 1 if the first character of s is 'l' or 'L', else 0.
 * Only s.text[0] is examined; s.size is not consulted.
 */
static int
long_suffix(struct str s)
{
    switch (*s.text) {
    case 'l':
    case 'L':
        return 1;
    default:
        return 0;
    }
}
/*
 * Length (2) of a leading "ll" or "LL", or 0. Mixed case ("lL", "Ll") is
 * not accepted, and neither is input shorter than two characters.
 */
static int
long_long_suffix(struct str s)
{
    if (s.size < 2) {
        return 0;
    }
    char first = s.text[0];
    if (first != 'l' && first != 'L') {
        return 0;
    }
    /* Both letters must have the same case. */
    return s.text[1] == first ? 2 : 0;
}
/*
 * 1 if the first character of s is '+' or '-', else 0.
 * Only s.text[0] is examined; s.size is not consulted.
 */
static int
sign(struct str s)
{
    switch (*s.text) {
    case '+':
    case '-':
        return 1;
    default:
        return 0;
    }
}
/*
 * 1 if the first character of s is one of 'f', 'F', 'l', 'L', else 0.
 * Only s.text[0] is examined; s.size is not consulted.
 */
static int
floating_suffix(struct str s)
{
    switch (*s.text) {
    case 'f':
    case 'F':
    case 'l':
    case 'L':
        return 1;
    default:
        return 0;
    }
}
/*
 * Length of the integer suffix at the start of s, or 0 if there is none.
 * Accepted shapes: an unsigned suffix optionally followed by a long-long or
 * long suffix, or a long-long / long suffix optionally followed by an
 * unsigned suffix.
 */
static int
integer_suffix(struct str s)
{
    int first = unsigned_suffix(s);
    if (first) {
        advance(&s, first);
        /* Prefer the two-letter form over the one-letter form. */
        int width = long_long_suffix(s);
        if (!width) {
            width = long_suffix(s);
        }
        return first + width;
    }

    first = long_long_suffix(s);
    if (!first) {
        first = long_suffix(s);
    }
    if (!first) {
        return 0;
    }
    advance(&s, first);
    return first + unsigned_suffix(s);
}
/*
 * Length (2) of a simple escape sequence at the start of s: a backslash
 * followed by one of ' " ? \ a b f n r t v. Returns 0 otherwise.
 */
static int
simple_escape_sequence(struct str s)
{
    if (s.size < 2 || s.text[0] != '\\') {
        return 0;
    }
    switch (s.text[1]) {
    case '\'':
    case '\"':
    case '?':
    case '\\':
    case 'a':
    case 'b':
    case 'f':
    case 'n':
    case 'r':
    case 't':
    case 'v':
        return 2;
    default:
        return 0;
    }
}
/*
 * Length of an octal escape sequence at the start of s: a backslash
 * followed by one to three octal digits (result 2..4). Returns 0 if s is
 * empty, does not start with a backslash, or no octal digit follows.
 */
static int
octal_escape_sequence(struct str s)
{
    if (!s.size || s.text[0] != '\\') {
        return 0;
    }
    advance(&s, 1);

    int digits = 0;
    while (digits < 3 && octal_digit(s)) {
        advance(&s, 1);
        ++digits;
    }
    return digits ? 1 + digits : 0;
}
/*
 * Length of a hexadecimal escape sequence at the start of s: "\x" followed
 * by one or more hexadecimal digits. Returns 0 when s does not start with
 * "\x" or no hexadecimal digit follows.
 */
static int
hexadecimal_escape_sequence(struct str s)
{
    if (s.size < 2 || s.text[0] != '\\' || s.text[1] != 'x') {
        return 0;
    }
    advance(&s, 2);

    int digits = 0;
    while (hexadecimal_digit(s)) {
        advance(&s, 1);
        ++digits;
    }
    return digits ? 2 + digits : 0;
}
/*
 * 4 if s holds at least four characters and the first four are all
 * hexadecimal digits, else 0.
 */
static int
hex_quad(struct str s)
{
    if (s.size < 4) {
        return 0;
    }
    for (int i = 0; i < 4; ++i) {
        if (!hexadecimal_digit(s)) {
            return 0;
        }
        advance(&s, 1);
    }
    return 4;
}
/*
 * Length of a universal character name at the start of s: "\u" plus one
 * hex quad (6) or "\U" plus two hex quads (10). Returns 0 otherwise.
 */
static int
universal_character_name(struct str s)
{
    if (s.size < 2 || s.text[0] != '\\') {
        return 0;
    }

    char kind = s.text[1];
    if (kind == 'u') {
        advance(&s, 2);
        int quad = hex_quad(s);
        return quad ? 2 + quad : 0;
    }
    if (kind == 'U') {
        advance(&s, 2);
        int high = hex_quad(s);
        if (!high) {
            return 0;
        }
        advance(&s, high);
        int low = hex_quad(s);
        return low ? 2 + high + low : 0;
    }
    return 0;
}
/*
 * Length of an identifier-nondigit at the start of s: 1 for a letter or
 * '_', otherwise the length of a universal character name, otherwise 0.
 */
static int
identifier_nondigit(struct str s)
{
    int len = nondigit(s);
    if (!len) {
        len = universal_character_name(s);
    }
    return len;
}
/*
 * Length of the identifier at the start of s, or 0 if s does not start
 * with an identifier-nondigit. After the first element, identifier-nondigits
 * and decimal digits are consumed for as long as they keep matching.
 */
static int
identifier(struct str s)
{
    int total = identifier_nondigit(s);
    if (!total) {
        return 0;
    }
    advance(&s, total);

    for (;;) {
        int step = identifier_nondigit(s);
        if (!step) {
            step = digit(s);
        }
        if (!step) {
            break;
        }
        advance(&s, step);
        total += step;
    }
    return total;
}
/*
 * Length of a decimal constant at the start of s: a non-zero digit followed
 * by any number of decimal digits. Returns 0 if the first character is not
 * a non-zero digit.
 */
static int
decimal_constant(struct str s)
{
    if (!nonzero_digit(s)) {
        return 0;
    }
    int len = 0;
    do {
        advance(&s, 1);
        ++len;
    } while (digit(s));
    return len;
}
/*
 * Length of an octal constant at the start of s: '0' followed by any number
 * of octal digits. Returns 0 if the first character is not '0'.
 */
static int
octal_constant(struct str s)
{
    if (s.text[0] != '0') {
        return 0;
    }
    advance(&s, 1);

    int len = 1;
    while (octal_digit(s)) {
        advance(&s, 1);
        ++len;
    }
    return len;
}
/*
 * Length of a hexadecimal constant at the start of s: a hexadecimal prefix
 * followed by at least one hexadecimal digit. Returns 0 when the prefix is
 * missing or no digit follows it.
 */
static int
hexadecimal_constant(struct str s)
{
    int prefix = hexadecimal_prefix(s);
    if (!prefix) {
        return 0;
    }
    advance(&s, prefix);

    int digits = 0;
    while (hexadecimal_digit(s)) {
        advance(&s, 1);
        ++digits;
    }
    return digits ? prefix + digits : 0;
}
/*
 * Length of an integer constant at the start of s, including an optional
 * integer suffix. The hexadecimal, octal and decimal forms are tried in
 * that order; returns 0 if none matches.
 */
static int
integer_constant(struct str s)
{
    int body = hexadecimal_constant(s);
    if (!body) {
        body = octal_constant(s);
    }
    if (!body) {
        body = decimal_constant(s);
    }
    if (!body) {
        return 0;
    }
    advance(&s, body);
    return body + integer_suffix(s);
}
/*
 * Number of consecutive decimal digits at the start of s (0 if none).
 */
static int
digit_sequence(struct str s)
{
    int len = 0;
    while (digit(s)) {
        advance(&s, 1);
        ++len;
    }
    return len;
}
/*
 * Length of a fractional constant at the start of s: optional digits, a
 * '.', optional digits, with at least one digit on either side of the dot.
 * Returns 0 otherwise.
 */
static int
fractional_constant(struct str s)
{
    int whole = digit_sequence(s);
    advance(&s, whole);
    if (s.text[0] != '.') {
        return 0;
    }
    advance(&s, 1);

    int frac = digit_sequence(s);
    if (whole == 0 && frac == 0) {
        /* A lone '.' is not a fractional constant. */
        return 0;
    }
    return whole + 1 + frac;
}
/*
 * Length of a decimal exponent part at the start of s: 'e' or 'E', an
 * optional sign, then at least one decimal digit. Returns 0 otherwise.
 */
static int
exponent_part(struct str s)
{
    if (s.text[0] != 'e' && s.text[0] != 'E') {
        return 0;
    }
    advance(&s, 1);

    int sign_len = sign(s);
    advance(&s, sign_len);

    int digits = digit_sequence(s);
    return digits ? 1 + sign_len + digits : 0;
}
/*
 * Length of a decimal floating constant at the start of s, including an
 * optional floating suffix. Two shapes are accepted: a fractional constant
 * with an optional exponent, or a digit sequence with a mandatory exponent.
 * Returns 0 otherwise.
 */
static int
decimal_floating_constant(struct str s)
{
    int mantissa = fractional_constant(s);
    if (mantissa) {
        advance(&s, mantissa);
        int exponent = exponent_part(s);
        advance(&s, exponent);
        return mantissa + exponent + floating_suffix(s);
    }

    mantissa = digit_sequence(s);
    if (!mantissa) {
        return 0;
    }
    advance(&s, mantissa);

    /* Without a '.', the exponent is what makes this a floating constant. */
    int exponent = exponent_part(s);
    if (!exponent) {
        return 0;
    }
    advance(&s, exponent);
    return mantissa + exponent + floating_suffix(s);
}
/*
 * Number of consecutive hexadecimal digits at the start of s (0 if none).
 */
static int
hexadecimal_digit_sequence(struct str s)
{
    int len = 0;
    while (hexadecimal_digit(s)) {
        advance(&s, 1);
        ++len;
    }
    return len;
}
/*
 * Length of a hexadecimal fractional constant at the start of s: optional
 * hex digits, a '.', optional hex digits, with at least one digit on either
 * side of the dot. Returns 0 otherwise.
 */
static int
hexadecimal_fractional_constant(struct str s)
{
    int whole = hexadecimal_digit_sequence(s);
    advance(&s, whole);
    if (s.text[0] != '.') {
        return 0;
    }
    advance(&s, 1);

    int frac = hexadecimal_digit_sequence(s);
    return (whole || frac) ? whole + frac + 1 : 0;
}
/*
 * Length of a binary exponent part at the start of s: 'p' or 'P', an
 * optional sign, then at least one decimal digit. Returns 0 otherwise.
 */
static int
binary_exponent_part(struct str s)
{
    if (s.text[0] != 'p' && s.text[0] != 'P') {
        return 0;
    }
    advance(&s, 1);

    int sign_len = sign(s);
    advance(&s, sign_len);

    int digits = digit_sequence(s);
    return digits ? 1 + sign_len + digits : 0;
}
/*
 * Length of a hexadecimal floating constant at the start of s: a
 * hexadecimal prefix, then a hexadecimal fractional constant or (failing
 * that) a hexadecimal digit sequence, then a mandatory binary exponent and
 * an optional floating suffix. Returns 0 otherwise.
 */
static int
hexadecimal_floating_constant(struct str s)
{
    int prefix = hexadecimal_prefix(s);
    if (!prefix) {
        return 0;
    }
    advance(&s, prefix);

    int mantissa = hexadecimal_fractional_constant(s);
    if (!mantissa) {
        mantissa = hexadecimal_digit_sequence(s);
    }
    if (!mantissa) {
        return 0;
    }
    advance(&s, mantissa);

    int exponent = binary_exponent_part(s);
    if (!exponent) {
        return 0;
    }
    advance(&s, exponent);
    return prefix + mantissa + exponent + floating_suffix(s);
}
/*
 * Length of a floating constant at the start of s: the decimal form is
 * tried first, then the hexadecimal form. Returns 0 if neither matches.
 */
static int
floating_constant(struct str s)
{
    int len = decimal_floating_constant(s);
    return len ? len : hexadecimal_floating_constant(s);
}
/*
 * Compiled out. An enumeration constant would be matched exactly like an
 * identifier; constant() has the corresponding call commented out.
 */
#if 0
static int
enumeration_constant(struct str s)
{
int i = identifier(s);
return(i);
}
#endif
/*
 * Length of the escape sequence at the start of s, or 0 if there is none.
 * Tried in order: simple, octal, hexadecimal, universal character name.
 */
static int
escape_sequence(struct str s)
{
    int len = simple_escape_sequence(s);
    if (!len) {
        len = octal_escape_sequence(s);
    }
    if (!len) {
        len = hexadecimal_escape_sequence(s);
    }
    if (!len) {
        len = universal_character_name(s);
    }
    return len;
}
/*
 * Length of one c-char (an element of a character constant) at the start
 * of s: the length of an escape sequence, or 1 for any single character
 * other than ', \ and newline. Returns 0 if there is no c-char.
 *
 * An exhausted cursor (s.size <= 0) never matches. Without this guard any
 * byte that is not a delimiter matches (including NUL), so an unterminated
 * constant on the last line would make c_char_sequence() keep walking past
 * the end of the buffer.
 */
static int
c_char(struct str s)
{
    if (s.size <= 0) {
        return(0);
    }
    int sym = escape_sequence(s);
    if (sym) {
        return(sym);
    }
    char c = s.text[0];
    return(c != '\'' && c != '\\' && c != '\n');
}
/*
 * 1 if the first character of s may appear inside a <...> header name
 * (anything other than newline and '>'), else 0.
 *
 * An exhausted cursor (s.size <= 0) never matches. Without this guard a '<'
 * with no closing '>' before the end of the buffer would make
 * h_char_sequence() keep walking past the end of the buffer.
 */
static int
h_char(struct str s)
{
    if (s.size <= 0) {
        return(0);
    }
    char c = s.text[0];
    return(c != '\n' && c != '>');
}
/*
 * 1 if the first character of s may appear inside a "..." header name
 * (anything other than newline and '"'), else 0.
 *
 * An exhausted cursor (s.size <= 0) never matches. Without this guard an
 * opening '"' with no closing '"' before the end of the buffer would make
 * q_char_sequence() keep walking past the end of the buffer.
 */
static int
q_char(struct str s)
{
    if (s.size <= 0) {
        return(0);
    }
    char c = s.text[0];
    return(c != '\n' && c != '\"');
}
/*
 * Total length of the run of c-chars at the start of s (0 if none).
 */
static int
c_char_sequence(struct str s)
{
    int total = 0;
    for (;;) {
        int step = c_char(s);
        if (!step) {
            break;
        }
        advance(&s, step);
        total += step;
    }
    return total;
}
/*
 * Length of a character constant at the start of s: an optional L, u or U
 * prefix, an opening ', a non-empty c-char-sequence and a closing '.
 * Returns 0 if s does not start with a complete character constant.
 *
 * Both quote checks are guarded by s.size, so neither an empty cursor nor
 * a constant left open at the end of the buffer reads past the end.
 */
static int
character_constant(struct str s)
{
    int prefix = 0;
    if (s.size >= 1 && s.text[0] == '\'') {
        prefix = 1;
    } else if (s.size >= 2 && s.text[1] == '\'' &&
               (s.text[0] == 'L' || s.text[0] == 'u' || s.text[0] == 'U')) {
        prefix = 2;
    }
    if (!prefix) {
        return(0);
    }
    advance(&s, prefix);

    int ccs = c_char_sequence(s);
    if (!ccs) {
        return(0);
    }
    advance(&s, ccs);

    if (s.size > 0 && s.text[0] == '\'') {
        return(prefix + ccs + 1);
    }
    return(0);
}
/*
 * Length of the constant at the start of s, or 0 if there is none.
 * Tried in order: floating, integer, character. Enumeration constants are
 * not tried.
 */
static int
constant(struct str s)
{
    int len = floating_constant(s);
    if (!len) {
        len = integer_constant(s);
    }
    if (!len) {
        len = character_constant(s);
    }
    return len;
}
/*
 * Number of leading whitespace characters in s (space, tab, newline,
 * carriage return). The scan stops at the first other character; s.size is
 * not consulted.
 */
static int
whitespace(struct str s)
{
    int len = 0;
    for (;;) {
        char ch = s.text[len];
        if (ch != ' ' && ch != '\t' && ch != '\n' && ch != '\r') {
            break;
        }
        ++len;
    }
    return len;
}
/*
 * Length of one s-char (an element of a string literal) at the start of s:
 * the length of an escape sequence, or 1 for any single character other
 * than ", \ and newline. Returns 0 if there is no s-char.
 *
 * An exhausted cursor (s.size <= 0) never matches. Without this guard any
 * byte that is not a delimiter matches (including NUL), so an unterminated
 * literal on the last line would make s_char_sequence() keep walking past
 * the end of the buffer.
 */
static int
s_char(struct str s)
{
    if (s.size <= 0) {
        return(0);
    }
    int sym = escape_sequence(s);
    if (sym) {
        return(sym);
    }
    char c = s.text[0];
    return(c != '\"' && c != '\\' && c != '\n');
}
/*
 * Total length of the run of s-chars at the start of s (0 if none).
 */
static int
s_char_sequence(struct str s)
{
    int total = 0;
    for (;;) {
        int step = s_char(s);
        if (!step) {
            break;
        }
        advance(&s, step);
        total += step;
    }
    return total;
}
/*
 * Length of a string-literal encoding prefix at the start of s: 2 for
 * "u8", 1 for 'u', 'U' or 'L', 0 otherwise. The single-character check
 * reads s.text[0] without consulting s.size.
 */
static int
encoding_prefix(struct str s)
{
    if (s.size >= 2 && s.text[0] == 'u' && s.text[1] == '8') {
        return 2;
    }
    switch (s.text[0]) {
    case 'u':
    case 'U':
    case 'L':
        return 1;
    default:
        return 0;
    }
}
/*
 * Length of a string literal at the start of s: an optional encoding
 * prefix, an opening ", a possibly empty s-char-sequence and a closing ".
 * Returns 0 if s does not start with a complete string literal.
 *
 * Every quote check is guarded by s.size, so neither an empty cursor nor a
 * literal left open at the end of the buffer reads past the end.
 */
static int
string_literal(struct str s)
{
    if (s.size <= 0) {
        return(0);
    }
    int start = s.size;

    int ep = encoding_prefix(s);
    advance(&s, ep);
    if (s.size <= 0 || s.text[0] != '\"') {
        return(0);
    }
    advance(&s, 1);

    int scs = s_char_sequence(s);
    advance(&s, scs);
    if (s.size > 0 && s.text[0] == '\"') {
        advance(&s, 1);
        return(start - s.size);
    }
    return(0);
}
/*
 * Length of the longest punctuator at the start of s (4, 3, 2 or 1), or 0
 * if s does not start with a punctuator. Longer spellings are tried first
 * so that, for example, "<<=" is not split into "<<" and "=".
 *
 * The two-character table includes the digraphs <: :> <% %> %: ; the
 * closing-brace digraph is "%>" (it used to be tested as ">%", which is
 * not a punctuator).
 */
static int
punctuator(struct str s)
{
    if (s.size >= 4) {
        if (s.text[0] == '%' && s.text[1] == ':' && s.text[2] == '%' && s.text[3] == ':') {
            return(4);
        }
    }
    if (s.size >= 3) {
        char c1 = s.text[0];
        char c2 = s.text[1];
        char c3 = s.text[2];
        if ((c1 == '.' && c2 == '.' && c3 == '.') ||
            (c1 == '<' && c2 == '<' && c3 == '=') ||
            (c1 == '>' && c2 == '>' && c3 == '=')) {
            return(3);
        }
    }
    if (s.size >= 2) {
        char c1 = s.text[0];
        char c2 = s.text[1];
        if ((c1 == '-' && c2 == '>') || (c1 == '+' && c2 == '+') ||
            (c1 == '-' && c2 == '-') || (c1 == '<' && c2 == '<') ||
            (c1 == '>' && c2 == '>') || (c1 == '<' && c2 == '=') ||
            (c1 == '>' && c2 == '=') || (c1 == '=' && c2 == '=') ||
            (c1 == '!' && c2 == '=') || (c1 == '&' && c2 == '&') ||
            (c1 == '|' && c2 == '|') || (c1 == '*' && c2 == '=') ||
            (c1 == '/' && c2 == '=') || (c1 == '%' && c2 == '=') ||
            (c1 == '+' && c2 == '=') || (c1 == '-' && c2 == '=') ||
            (c1 == '&' && c2 == '=') || (c1 == '^' && c2 == '=') ||
            (c1 == '|' && c2 == '=') || (c1 == '#' && c2 == '#') ||
            (c1 == '<' && c2 == ':') || (c1 == ':' && c2 == '>') ||
            (c1 == '<' && c2 == '%') || (c1 == '%' && c2 == '>') ||
            (c1 == '%' && c2 == ':')) {
            return(2);
        }
    }
    if (s.size >= 1) {
        char c = s.text[0];
        if (c == '[' || c == ']' || c == '(' || c == ')' ||
            c == '{' || c == '}' || c == '.' || c == '&' ||
            c == '*' || c == '+' || c == '-' || c == '~' ||
            c == '!' || c == '/' || c == '%' || c == '<' ||
            c == '>' || c == '^' || c == '|' || c == '?' ||
            c == ':' || c == ';' || c == '=' || c == ',' ||
            c == '#') {
            return(1);
        }
    }
    return(0);
}
/*
 * Total length of the run of h-chars at the start of s (0 if none).
 */
static int
h_char_sequence(struct str s)
{
    int total = 0;
    for (;;) {
        int step = h_char(s);
        if (!step) {
            break;
        }
        advance(&s, step);
        total += step;
    }
    return total;
}
/*
 * Total length of the run of q-chars at the start of s (0 if none).
 */
static int
q_char_sequence(struct str s)
{
    int total = 0;
    for (;;) {
        int step = q_char(s);
        if (!step) {
            break;
        }
        advance(&s, step);
        total += step;
    }
    return total;
}
/*
 * Length of a header name at the start of s: '<' h-char-sequence '>' or
 * '"' q-char-sequence '"', with a non-empty sequence in both forms.
 * Returns 0 if s does not start with a complete header name.
 *
 * Every delimiter check is guarded by s.size, so neither an empty cursor
 * nor a name left open at the end of the buffer reads past the end. The two
 * forms are selected on the opening delimiter; a failed '<' form does not
 * fall through to the '"' form on the already-advanced cursor.
 */
static int
header_name(struct str s)
{
    if (s.size <= 0) {
        return(0);
    }

    char open = s.text[0];
    if (open == '<') {
        advance(&s, 1);
        int hcs = h_char_sequence(s);
        if (hcs) {
            advance(&s, hcs);
            if (s.size > 0 && s.text[0] == '>') {
                return(hcs + 2);
            }
        }
    } else if (open == '\"') {
        advance(&s, 1);
        int qcs = q_char_sequence(s);
        if (qcs) {
            advance(&s, qcs);
            if (s.size > 0 && s.text[0] == '\"') {
                return(qcs + 2);
            }
        }
    }
    return(0);
}
/*
 * Length of the comment at the start of s, or 0 if s does not start with
 * one.
 *
 * A "//" comment extends through the next newline (the newline is counted)
 * or to the end of the buffer if there is none.
 *
 * A slash-star comment extends through its closing star-slash. If the
 * closing marker is never found, or nothing remains after it, 0 is
 * returned.
 */
static int
comment(struct str s)
{
int start = s.size;
if (s.size >= 2) {
if (s.text[0] == '/' && s.text[1] == '/') {
/* single-line comment */
advance(&s, 2);
while (s.size) {
if (s.text[0] == '\n') {
// The newline belongs to the comment.
advance(&s, 1);
break;
}
advance(&s, 1);
}
return(start - s.size);
}
if (s.text[0] == '/' && s.text[1] == '*') {
/* multi-line comment */
advance(&s, 2);
// Every 16-bit lane of `mask` holds the byte pair '*' '/'.
__m128i mask = _mm_setr_epi8('*', '/', '*', '/', '*', '/', '*', '/', '*', '/', '*', '/', '*', '/', '*', '/');
// `mask_sus` has '*' in the last byte only and zero everywhere else.
__m128i mask_sus = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, '*');
// Fast path: skip 16 bytes at a time while the chunk cannot contain
// (or begin) the closing marker. Runs only while more than 16 bytes remain.
while (s.size > 16) {
__m128i chunk = _mm_loadu_si128((__m128i *)(s.text));
// v1: closing marker starting at an even byte offset of the chunk.
__m128i v1 = _mm_cmpeq_epi16(chunk, mask);
// v2: chunk shifted up by one byte, so pairs starting at odd offsets
// line up with the 16-bit lanes.
__m128i v2 = _mm_cmpeq_epi16(_mm_bslli_si128(chunk, 1), mask);
// v3: '*' in the last byte (the marker could straddle into the next
// chunk). Because the other bytes of mask_sus are zero, a zero byte
// in positions 0..14 also matches and ends the fast path.
__m128i v3 = _mm_cmpeq_epi8(chunk, mask_sus);
__m128i v12 = _mm_or_si128(v1, v2);
__m128i v123 = _mm_or_si128(v12, v3);
if (!_mm_testz_si128(v123, v123)) {
// Something matched: let the byte-wise loop below locate it.
break;
}
advance(&s, 16);
}
// Byte-wise scan for the closing marker.
while (s.size) {
if (s.size >= 2 && s.text[0] == '*' && s.text[1] == '/') {
advance(&s, 2);
break;
}
advance(&s, 1);
}
// s.size == 0 here means the scan ran to the end of the buffer.
if (s.size) {
return(start - s.size);
}
}
}
return(0);
}
static struct token *
lex(char *text, int size)
{
struct str s = { text, size };
while (s.size) {
int sym = whitespace(s);
advance(&s, sym);
if ((sym = comment(s))) {
//printf("Comment: ");
//printf("%.*s\n", sym, s.text);
} else {
int sym_constant = constant(s);
int sym_punctuator = punctuator(s);
int sym_string = string_literal(s);
int sym_header = header_name(s);
int sym_identifier = identifier(s);
sym = MAX(sym_constant, MAX(sym_punctuator, MAX(sym_string, MAX(sym_header, sym_identifier))));
#if 0
if (sym == sym_constant) {
printf("Constant: ");
} else if (sym == sym_punctuator) {
printf("Punctuator: ");
} else if (sym == sym_string) {
printf("String: ");
} else if (sym == sym_header) {
printf("Header: ");
} else if (sym == sym_identifier) {
printf("Identifier: ");
}
#endif
}
if (sym) {
advance(&s, sym);
} else if (s.size == 1 && s.text[0] == '\0') {
break;
} else {
fprintf(stderr, "Error!\n");
break;
}
}
return(NULL);
}
/*
 * Blank out line continuations whose backslash lies within the first
 * `chunk_size` bytes of `data`.
 *
 * For each such backslash, the bytes after it are scanned (up to
 * `full_size`) past any spaces and tabs. If the next byte is a newline,
 * both the backslash and the newline are overwritten with a space;
 * otherwise nothing is changed for that backslash.
 */
static void
preprocess_scalar(char *data, int chunk_size, int full_size)
{
    for (int pos = 0; pos < chunk_size; ++pos) {
        if (data[pos] != '\\') {
            continue;
        }

        int look = pos + 1;
        while (look < full_size && (data[look] == ' ' || data[look] == '\t')) {
            ++look;
        }
        if (look < full_size && data[look] == '\n') {
            data[look] = ' ';
            data[pos] = ' ';
        }
    }
}
/*
 * Blank out backslash-newline line continuations in data[0..size).
 *
 * The buffer is scanned in 16-byte chunks; preprocess_scalar() is called
 * only for chunks that contain a backslash, and is allowed to look ahead to
 * the end of the buffer. The bytes left over after the last full chunk are
 * always handed to preprocess_scalar().
 */
static void
preprocess(char *data, int size)
{
    const int lane_bytes = 16;
    const int vector_bytes = size & (~(lane_bytes - 1));
    const __m128i backslashes = _mm_set1_epi8('\\');

    for (int offset = 0; offset < vector_bytes; offset += lane_bytes) {
        __m128i block = _mm_loadu_si128((__m128i *) (data + offset));
        __m128i hits = _mm_cmpeq_epi8(block, backslashes);
        if (_mm_testz_si128(hits, hits)) {
            /* No backslash in this chunk. */
            continue;
        }
        preprocess_scalar(data + offset, lane_bytes, size - offset);
    }

    int tail = size - vector_bytes;
    preprocess_scalar(data + vector_bytes, tail, tail);
}
/*
 * Preprocess and lex data[0..size), then print the elapsed time in
 * milliseconds and the throughput (size divided by elapsed microseconds)
 * to stderr. The result of lex() is discarded.
 */
static void
run(char *data, int size)
{
    unsigned long long started = usec_now();

    preprocess(data, size);
    (void) lex(data, size);

    unsigned long long finished = usec_now();
    float elapsed_us = finished - started;
    fprintf(stderr, "%.2fms, %.2fMB/s\n", elapsed_us / 1000, size / elapsed_us);
}
/*
 * Entry point: map the file named by argv[1] privately, replace its final
 * newline with a NUL terminator and hand the buffer to run().
 *
 * Returns 0 on success and 1 on a usage error, an I/O error, a file too
 * large for an int size, or a file that does not end in a newline. An empty
 * file is treated as success with nothing to do.
 *
 * Unlike the earlier version, the mapping is validated before it is touched,
 * the empty-file case no longer indexes data[-1], the file size is range
 * checked before being narrowed to int, and the descriptor and mapping are
 * released on every path.
 */
int
main(int argc, char **argv)
{
    if (argc != 2) {
        fprintf(stderr, "Usage: %s input_file.c\n", argv[0]);
        return(1);
    }

    const char *file = argv[1];
    int fd = open(file, O_RDONLY);
    if (fd == -1) {
        perror("open");
        return(1);
    }

    struct stat sb = { 0 };
    if (fstat(fd, &sb) == -1) {
        perror("fstat");
        close(fd);
        return(1);
    }
    if (sb.st_size <= 0) {
        /* Nothing to lex; mmap() rejects a zero length anyway. */
        close(fd);
        return(0);
    }
    if (sb.st_size > INT_MAX) {
        fprintf(stderr, "%s: file too large\n", file);
        close(fd);
        return(1);
    }
    int size = (int) sb.st_size;

    char *data = mmap(NULL, (size_t) size, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd, 0);
    /* The mapping stays valid after the descriptor is closed. */
    close(fd);
    if (data == MAP_FAILED) {
        perror("mmap");
        return(1);
    }

    int status = 0;
    if (data[size - 1] != '\n') {
        fprintf(stderr, "No terminating new line. Fuck you!\n");
        status = 1;
    } else {
        /* The lexer relies on a NUL sentinel as the last byte. */
        data[size - 1] = '\0';
        run(data, size);
    }

    munmap(data, (size_t) size);
    return(status);
}