diff --git a/ast/ast.go b/ast/ast.go index f5f6021..fe190f6 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -128,3 +128,28 @@ func (be *BinaryExpression) TokenLiteral() string { return be.Token.Literal } func (be *BinaryExpression) String() string { return fmt.Sprintf("(%s %s %s)", be.Lhs, be.Operator.SymbolString(), be.Rhs) } + +type BlockExpression struct { + Token token.Token // The '{' + Expressions []Expression + ReturnExpression Expression // A expression that does not end with a semicolon, there can only be one of those and it hast to be at the end +} + +func (be *BlockExpression) expressionNode() {} +func (be *BlockExpression) TokenLiteral() string { return be.Token.Literal } +func (be *BlockExpression) String() string { + var builder strings.Builder + + builder.WriteString("({\n") + for _, expr := range be.Expressions { + builder.WriteString("\t") + builder.WriteString(expr.String()) + builder.WriteString(";\n") + } + if be.ReturnExpression != nil { + builder.WriteString(fmt.Sprintf("\t%s\n", be.ReturnExpression.String())) + } + builder.WriteString("})") + + return builder.String() +} diff --git a/lexer/lexer.go b/lexer/lexer.go index 69acd52..d45dd83 100644 --- a/lexer/lexer.go +++ b/lexer/lexer.go @@ -87,6 +87,10 @@ func (l *Lexer) NextToken() token.Token { tok = l.newToken(token.Asterisk) case '/': tok = l.newToken(token.Slash) + case '{': + tok = l.newToken(token.OpenBrack) + case '}': + tok = l.newToken(token.CloseBrack) case '!': if l.peekByte() == '=' { pos := l.position diff --git a/parser/parser.go b/parser/parser.go index 14e1522..1d4b04b 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -13,9 +13,9 @@ type precedence int const ( PrecLowest precedence = iota + PrecComparison PrecSum PrecProduct - PrecComparison ) var precedences = map[token.TokenType]precedence{ @@ -50,6 +50,8 @@ func New(l *lexer.Lexer) *Parser { p.registerPrefixFn(token.Int, p.parseIntegerExpression) p.registerPrefixFn(token.True, p.parseBooleanExpression) p.registerPrefixFn(token.False, p.parseBooleanExpression) + p.registerPrefixFn(token.OpenParen, p.parseGroupedExpression) + p.registerPrefixFn(token.OpenBrack, p.parseBlockExpression) p.infixParseFns = make(map[token.TokenType]infixParseFn) p.registerInfixFn(token.Plus, p.parseBinaryExpression) @@ -128,22 +130,22 @@ func (p *Parser) exprError(invalidToken token.Token, format string, args ...any) } } -func (p *Parser) expect(tt token.TokenType) bool { +func (p *Parser) expect(tt token.TokenType) (bool, ast.Expression) { if p.curToken.Type != tt { p.error(p.curToken, "expected %q, got %q", tt, p.curToken.Type) - return false + return false, &ast.ErrorExpression{InvalidToken: p.curToken} } - return true + return true, nil } -func (p *Parser) expectPeek(tt token.TokenType) bool { +func (p *Parser) expectPeek(tt token.TokenType) (bool, ast.Expression) { if p.peekToken.Type != tt { p.error(p.peekToken, "expected %q, got %q", tt, p.peekToken.Type) p.nextToken() - return false + return false, nil } p.nextToken() - return true + return true, &ast.ErrorExpression{InvalidToken: p.curToken} } func (p *Parser) ParseProgram() *ast.Program { @@ -163,28 +165,28 @@ func (p *Parser) ParseProgram() *ast.Program { } func (p *Parser) parseDeclaration() ast.Declaration { - if !p.expect(token.Fn) { + if ok, _ := p.expect(token.Fn); !ok { return nil } tok := p.curToken - if !p.expectPeek(token.Ident) { + if ok, _ := p.expectPeek(token.Ident); !ok { return nil } name := p.curToken.Literal - if !p.expectPeek(token.OpenParen) { + if ok, _ := p.expectPeek(token.OpenParen); !ok { return nil } - if !p.expectPeek(token.CloseParen) { + if ok, _ := p.expectPeek(token.CloseParen); !ok { return nil } - if !p.expectPeek(token.Equal) { + if ok, _ := p.expectPeek(token.Equal); !ok { return nil } p.nextToken() expr := p.parseExpression(PrecLowest) - if !p.expectPeek(token.Semicolon) { + if ok, _ := p.expectPeek(token.Semicolon); !ok { return nil } @@ -219,8 +221,8 @@ func (p *Parser) parseExpression(precedence precedence) ast.Expression { } func (p *Parser) parseIntegerExpression() ast.Expression { - if !p.expect(token.Int) { - return &ast.ErrorExpression{InvalidToken: p.curToken} + if ok, errExpr := p.expect(token.Int); !ok { + return errExpr } int := &ast.IntegerExpression{ @@ -279,3 +281,40 @@ func (p *Parser) parseBinaryExpression(lhs ast.Expression) ast.Expression { return &ast.BinaryExpression{Lhs: lhs, Rhs: rhs, Operator: op, Token: tok} } + +func (p *Parser) parseGroupedExpression() ast.Expression { + p.expect(token.OpenParen) + + p.nextToken() + expr := p.parseExpression(PrecLowest) + + if ok, errExpr := p.expectPeek(token.CloseParen); !ok { + return errExpr + } + + return expr +} + +func (p *Parser) parseBlockExpression() ast.Expression { + if ok, errExpr := p.expect(token.OpenBrack); !ok { + return errExpr + } + block := &ast.BlockExpression{Token: p.curToken} + + p.nextToken() + for !p.curTokenIs(token.CloseBrack) { + expr := p.parseExpression(PrecLowest) + if p.peekTokenIs(token.Semicolon) { + block.Expressions = append(block.Expressions, expr) + p.nextToken() + p.nextToken() + } else if p.peekTokenIs(token.CloseBrack) { + block.ReturnExpression = expr + p.nextToken() + } else { + return p.exprError(p.peekToken, "expected a ';' or '}' to either end the current expression or block, but got %q instead.", p.peekToken.Type) + } + } + + return block +} diff --git a/parser/parser_test.go b/parser/parser_test.go index a9d3360..509dd9a 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -43,11 +43,11 @@ func runParserTest(test parserTest, t *testing.T) { } for i, decl := range test.expectedProgram.Declarations { - expectDeclarationSame(t, decl, actual.Declarations[i]) + expectDeclaration(t, decl, actual.Declarations[i]) } } -func expectDeclarationSame(t *testing.T, expected ast.Declaration, actual ast.Declaration) { +func expectDeclaration(t *testing.T, expected ast.Declaration, actual ast.Declaration) { t.Helper() switch expected := expected.(type) { @@ -68,6 +68,13 @@ func expectDeclarationSame(t *testing.T, expected ast.Declaration, actual ast.De func expectExpression(t *testing.T, expected ast.Expression, actual ast.Expression) { t.Helper() + if expected == nil { + if actual != nil { + t.Errorf("expected a nil expression but got %v", actual) + } + return + } + switch expected := expected.(type) { case *ast.ErrorExpression: actual, ok := actual.(*ast.ErrorExpression) @@ -109,6 +116,21 @@ func expectExpression(t *testing.T, expected ast.Expression, actual ast.Expressi if booleanExpr.Value != expected.Value { t.Errorf("expected boolean %v, got %v", expected.Value, booleanExpr.Value) } + case *ast.BlockExpression: + blockExpr, ok := actual.(*ast.BlockExpression) + if !ok { + t.Errorf("expected %T, got %T", expected, actual) + return + } + + if len(expected.Expressions) != len(blockExpr.Expressions) { + t.Errorf("expected block with %d expressions, got %d", len(expected.Expressions), len(blockExpr.Expressions)) + return + } + for i, expectedExpression := range expected.Expressions { + expectExpression(t, expectedExpression, blockExpr.Expressions[i]) + } + expectExpression(t, expected.ReturnExpression, blockExpr.ReturnExpression) default: t.Fatalf("unknown expression type %T", expected) } @@ -152,3 +174,45 @@ func TestBinaryExpressions(t *testing.T) { runParserTest(test, t) } + +func TestBlockExpression(t *testing.T) { + test := parserTest{ + input: "fn main() = {\n3;\n{ 3+2 }\n}\n;", + expectedProgram: ast.Program{ + Declarations: []ast.Declaration{ + &ast.FunctionDeclaration{ + Name: "main", + Body: &ast.BlockExpression{ + Expressions: []ast.Expression{ + &ast.IntegerExpression{Value: 3}, + }, + ReturnExpression: &ast.BlockExpression{ + Expressions: []ast.Expression{}, + ReturnExpression: &ast.BinaryExpression{ + Lhs: &ast.IntegerExpression{Value: 3}, + Rhs: &ast.IntegerExpression{Value: 2}, + Operator: ast.Add, + }, + }, + }, + }, + }, + }, + } + runParserTest(test, t) +} + +func TestGroupedExpression(t *testing.T) { + test := parserTest{ + input: "fn main() = (3);", + expectedProgram: ast.Program{ + Declarations: []ast.Declaration{ + &ast.FunctionDeclaration{ + Name: "main", + Body: &ast.IntegerExpression{Value: 3}, + }, + }, + }, + } + runParserTest(test, t) +} diff --git a/test.tt b/test.tt index 57ee435..3e261c8 100644 --- a/test.tt +++ b/test.tt @@ -1 +1,5 @@ -fn main() = 3 + 3 != false; +fn main() = { + 3 + 3; + 3 * 43 * 34 / 34; + 3 +}; diff --git a/token/token.go b/token/token.go index c2c2494..a4053f8 100644 --- a/token/token.go +++ b/token/token.go @@ -32,6 +32,8 @@ const ( Equal TokenType = "=" OpenParen TokenType = "(" CloseParen TokenType = ")" + OpenBrack TokenType = "{" + CloseBrack TokenType = "}" // Binary Operators Plus TokenType = "+"