diff --git a/ast/ast.go b/ast/ast.go index 70feafe..0b9e983 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -215,3 +215,15 @@ func (vr *VariableReference) TokenLiteral() string { return vr.Token.Literal } func (vr *VariableReference) String() string { return fmt.Sprintf("%s", vr.Identifier) } + +type AssignmentExpression struct { + Token token.Token // The Equal + Lhs Expression + Rhs Expression +} + +func (ae *AssignmentExpression) expressionNode() {} +func (ae *AssignmentExpression) TokenLiteral() string { return ae.Token.Literal } +func (ae *AssignmentExpression) String() string { + return fmt.Sprintf("%s = %s", ae.Lhs.String(), ae.Rhs.String()) +} diff --git a/flake.nix b/flake.nix index 62d4194..c184705 100644 --- a/flake.nix +++ b/flake.nix @@ -59,7 +59,7 @@ in { default = pkgs.mkShell { - buildInputs = with pkgs; [ go gopls gotools go-tools ]; + buildInputs = with pkgs; [ go gopls gotools go-tools qbe fasm ]; }; }); diff --git a/parser/parser.go b/parser/parser.go index 73963aa..7b6c56c 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -16,6 +16,7 @@ const ( PrecComparison PrecSum PrecProduct + PrecAssignment ) var precedences = map[token.TokenType]precedence{ @@ -29,6 +30,7 @@ var precedences = map[token.TokenType]precedence{ token.GreaterThanEqual: PrecComparison, token.LessThan: PrecComparison, token.LessThanEqual: PrecComparison, + token.Equal: PrecAssignment, } type ErrorCallback func(token.Token, string, ...any) @@ -71,6 +73,8 @@ func New(l *lexer.Lexer) *Parser { p.registerInfixFn(token.LessThan, p.parseBinaryExpression) p.registerInfixFn(token.LessThanEqual, p.parseBinaryExpression) + p.registerInfixFn(token.Equal, p.parseAssignmentExpression) + p.nextToken() p.nextToken() @@ -340,9 +344,10 @@ func (p *Parser) parseVariable() ast.Expression { return errExpr } - if p.peekTokenIs(token.Colon) { + switch p.peekToken.Type { + case token.Colon: return p.parseVariableDeclaration() - } else { + default: return &ast.VariableReference{ Token: p.curToken, Identifier: p.curToken.Literal, @@ -415,3 +420,20 @@ func (p *Parser) parseBinaryExpression(lhs ast.Expression) ast.Expression { return &ast.BinaryExpression{Lhs: lhs, Rhs: rhs, Operator: op, Token: tok} } + +func (p *Parser) parseAssignmentExpression(lhs ast.Expression) ast.Expression { + if ok, errExpr := p.expect(token.Equal); !ok { + return errExpr + } + + varAss := &ast.AssignmentExpression{ + Token: p.curToken, + Lhs: lhs, + } + + p.nextToken() + + varAss.Rhs = p.parseExpression(PrecLowest) + + return varAss +} diff --git a/parser/parser_test.go b/parser/parser_test.go index 509dd9a..13683ba 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -131,6 +131,33 @@ func expectExpression(t *testing.T, expected ast.Expression, actual ast.Expressi expectExpression(t, expectedExpression, blockExpr.Expressions[i]) } expectExpression(t, expected.ReturnExpression, blockExpr.ReturnExpression) + case *ast.VariableDeclaration: + varDecl, ok := actual.(*ast.VariableDeclaration) + if !ok { + t.Errorf("expected %T, got %T", expected, actual) + return + } + + if expected.Identifier != varDecl.Identifier { + t.Errorf("expected variable identifier to be %q, got %q", expected.Identifier, varDecl.Identifier) + } + + if expected.Type != varDecl.Type { + t.Errorf("expected variable type to be %q, got %q", expected.Type, varDecl.Type) + } + + expectExpression(t, expected.InitializingExpression, varDecl.InitializingExpression) + case *ast.VariableReference: + varRef, ok := actual.(*ast.VariableReference) + + if !ok { + t.Errorf("expected %T, got %T", expected, actual) + return + } + + if expected.Identifier != varRef.Identifier { + t.Errorf("expected variable reference identifier to be %q but got %q", expected.Identifier, varRef.Identifier) + } default: t.Fatalf("unknown expression type %T", expected) } @@ -216,3 +243,27 @@ func TestGroupedExpression(t *testing.T) { } runParserTest(test, t) } + +func TestVariableExpression(t *testing.T) { + test := parserTest{ + input: "fn main() = { x : u32 = 3; x };", + expectedProgram: ast.Program{ + Declarations: []ast.Declaration{ + &ast.FunctionDeclaration{ + Name: "main", + Body: &ast.BlockExpression{ + Expressions: []ast.Expression{ + &ast.VariableDeclaration{ + InitializingExpression: &ast.IntegerExpression{Value: 3}, + Identifier: "x", + Type: "u32", + }, + }, + ReturnExpression: &ast.VariableReference{Identifier: "x"}, + }, + }, + }, + }, + } + runParserTest(test, t) +} diff --git a/test.tt b/test.tt index 486e533..ee59cbe 100644 --- a/test.tt +++ b/test.tt @@ -1,11 +1,8 @@ fn main() = { hi: i64 = 4; - if hi == 2 { - hi = 3 - } else { - hi = 2 - } + if hi == 2 in hi = 3 + else hi = 2; hi };