a lot of stuff

This commit is contained in:
Robin Bärtschi 2025-02-23 19:09:03 +01:00
parent ee26d1371c
commit 1b2ebd2361
11 changed files with 281 additions and 56 deletions

View File

@ -9,6 +9,7 @@ import (
type Node interface { type Node interface {
TokenLiteral() string TokenLiteral() string
Tok() token.Token
String() string String() string
} }
@ -33,6 +34,13 @@ func (p *Program) TokenLiteral() string {
return "" return ""
} }
func (p *Program) Tok() token.Token {
if len(p.Declarations) > 0 {
return p.Declarations[0].Tok()
}
return token.Token{}
}
func (p *Program) String() string { func (p *Program) String() string {
var builder strings.Builder var builder strings.Builder
@ -52,6 +60,7 @@ type FunctionDeclaration struct {
func (fd *FunctionDeclaration) declarationNode() {} func (fd *FunctionDeclaration) declarationNode() {}
func (fd *FunctionDeclaration) TokenLiteral() string { return fd.Token.Literal } func (fd *FunctionDeclaration) TokenLiteral() string { return fd.Token.Literal }
func (fd *FunctionDeclaration) Tok() token.Token { return fd.Token }
func (fd *FunctionDeclaration) String() string { func (fd *FunctionDeclaration) String() string {
return fmt.Sprintf("fn %v() = %v;", fd.Name, fd.Body.String()) return fmt.Sprintf("fn %v() = %v;", fd.Name, fd.Body.String())
} }
@ -63,6 +72,7 @@ type ErrorExpression struct {
func (e *ErrorExpression) expressionNode() {} func (e *ErrorExpression) expressionNode() {}
func (e *ErrorExpression) TokenLiteral() string { return e.InvalidToken.Literal } func (e *ErrorExpression) TokenLiteral() string { return e.InvalidToken.Literal }
func (e *ErrorExpression) Tok() token.Token { return e.InvalidToken }
func (e *ErrorExpression) String() string { return "<ERROR EXPR>" } func (e *ErrorExpression) String() string { return "<ERROR EXPR>" }
type IntegerExpression struct { type IntegerExpression struct {
@ -72,6 +82,7 @@ type IntegerExpression struct {
func (ie *IntegerExpression) expressionNode() {} func (ie *IntegerExpression) expressionNode() {}
func (ie *IntegerExpression) TokenLiteral() string { return ie.Token.Literal } func (ie *IntegerExpression) TokenLiteral() string { return ie.Token.Literal }
func (ie *IntegerExpression) Tok() token.Token { return ie.Token }
func (ie *IntegerExpression) String() string { return ie.Token.Literal } func (ie *IntegerExpression) String() string { return ie.Token.Literal }
type BooleanExpression struct { type BooleanExpression struct {
@ -81,6 +92,7 @@ type BooleanExpression struct {
func (be *BooleanExpression) expressionNode() {} func (be *BooleanExpression) expressionNode() {}
func (be *BooleanExpression) TokenLiteral() string { return be.Token.Literal } func (be *BooleanExpression) TokenLiteral() string { return be.Token.Literal }
func (be *BooleanExpression) Tok() token.Token { return be.Token }
func (be *BooleanExpression) String() string { return be.Token.Literal } func (be *BooleanExpression) String() string { return be.Token.Literal }
//go:generate stringer -type=BinaryOperator //go:generate stringer -type=BinaryOperator
@ -137,6 +149,7 @@ type BinaryExpression struct {
func (be *BinaryExpression) expressionNode() {} func (be *BinaryExpression) expressionNode() {}
func (be *BinaryExpression) TokenLiteral() string { return be.Token.Literal } func (be *BinaryExpression) TokenLiteral() string { return be.Token.Literal }
func (be *BinaryExpression) Tok() token.Token { return be.Token }
func (be *BinaryExpression) String() string { func (be *BinaryExpression) String() string {
return fmt.Sprintf("(%s %s %s)", be.Lhs, be.Operator.SymbolString(), be.Rhs) return fmt.Sprintf("(%s %s %s)", be.Lhs, be.Operator.SymbolString(), be.Rhs)
} }
@ -149,6 +162,7 @@ type BlockExpression struct {
func (be *BlockExpression) expressionNode() {} func (be *BlockExpression) expressionNode() {}
func (be *BlockExpression) TokenLiteral() string { return be.Token.Literal } func (be *BlockExpression) TokenLiteral() string { return be.Token.Literal }
func (be *BlockExpression) Tok() token.Token { return be.Token }
func (be *BlockExpression) String() string { func (be *BlockExpression) String() string {
var builder strings.Builder var builder strings.Builder
@ -170,12 +184,13 @@ type IfExpression struct {
Token token.Token // The 'if' token Token token.Token // The 'if' token
Condition Expression Condition Expression
Then Expression Then Expression
// Can be nil // NOTE: Can be nil
Else Expression Else Expression
} }
func (ie *IfExpression) expressionNode() {} func (ie *IfExpression) expressionNode() {}
func (ie *IfExpression) TokenLiteral() string { return ie.Token.Literal } func (ie *IfExpression) TokenLiteral() string { return ie.Token.Literal }
func (ie *IfExpression) Tok() token.Token { return ie.Token }
func (ie *IfExpression) String() string { func (ie *IfExpression) String() string {
var builder strings.Builder var builder strings.Builder
@ -200,6 +215,7 @@ type VariableDeclaration struct {
func (vd *VariableDeclaration) expressionNode() {} func (vd *VariableDeclaration) expressionNode() {}
func (vd *VariableDeclaration) TokenLiteral() string { return vd.Token.Literal } func (vd *VariableDeclaration) TokenLiteral() string { return vd.Token.Literal }
func (vd *VariableDeclaration) Tok() token.Token { return vd.Token }
func (vd *VariableDeclaration) String() string { func (vd *VariableDeclaration) String() string {
return fmt.Sprintf("%s : %v = %s", vd.Identifier, vd.Type, vd.InitializingExpression) return fmt.Sprintf("%s : %v = %s", vd.Identifier, vd.Type, vd.InitializingExpression)
} }
@ -211,6 +227,7 @@ type VariableReference struct {
func (vr *VariableReference) expressionNode() {} func (vr *VariableReference) expressionNode() {}
func (vr *VariableReference) TokenLiteral() string { return vr.Token.Literal } func (vr *VariableReference) TokenLiteral() string { return vr.Token.Literal }
func (vr *VariableReference) Tok() token.Token { return vr.Token }
func (vr *VariableReference) String() string { func (vr *VariableReference) String() string {
return fmt.Sprintf("%s", vr.Identifier) return fmt.Sprintf("%s", vr.Identifier)
} }
@ -223,6 +240,7 @@ type AssignmentExpression struct {
func (ae *AssignmentExpression) expressionNode() {} func (ae *AssignmentExpression) expressionNode() {}
func (ae *AssignmentExpression) TokenLiteral() string { return ae.Token.Literal } func (ae *AssignmentExpression) TokenLiteral() string { return ae.Token.Literal }
func (ae *AssignmentExpression) Tok() token.Token { return ae.Token }
func (ae *AssignmentExpression) String() string { func (ae *AssignmentExpression) String() string {
return fmt.Sprintf("%s = %s", ae.Lhs.String(), ae.Rhs.String()) return fmt.Sprintf("%s = %s", ae.Lhs.String(), ae.Rhs.String())
} }

View File

@ -145,7 +145,14 @@ func (rft *funcTask) WithName(name string) {
rft.name = name rft.name = name
} }
func build(outputWriter io.Writer, input string, output string, toPrint ToPrintFlags, backend asm.Backend) error { func build(outputWriter io.Writer, input string, output string, toPrint ToPrintFlags, backend asm.Backend) (err error) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = fmt.Errorf("panic in build: %#v", panicErr)
}
}()
file, err := os.Open(input) file, err := os.Open(input)
if err != nil { if err != nil {
return fmt.Errorf("could not open file %q because: %v", input, err) return fmt.Errorf("could not open file %q because: %v", input, err)
@ -217,7 +224,7 @@ func build(outputWriter io.Writer, input string, output string, toPrint ToPrintF
} }
return nil return
} }
type taskResult struct { type taskResult struct {

View File

@ -27,7 +27,7 @@ type Lexer struct {
} }
func New(input string, file string) (*Lexer, error) { func New(input string, file string) (*Lexer, error) {
l := &Lexer{input: input, file: file} l := &Lexer{input: input, file: file, lineCount: 1}
if err := l.readChar(); err != nil { if err := l.readChar(); err != nil {
return nil, err return nil, err
} }

View File

@ -59,7 +59,7 @@ func main() {
toPrint |= build.PrintIr toPrint |= build.PrintIr
} }
logger := log.New(os.Stderr, "", log.Lshortfile) _ = log.New(os.Stderr, "", log.Lshortfile)
backend := asm.Fasm backend := asm.Fasm
if *qbe { if *qbe {
@ -68,7 +68,7 @@ func main() {
err := build.NewSourceProgram(input, output).Build(backend, *emitAsmOnly, build.ToPrintFlags(toPrint)) err := build.NewSourceProgram(input, output).Build(backend, *emitAsmOnly, build.ToPrintFlags(toPrint))
if err != nil { if err != nil {
logger.Fatalln(err) os.Stderr.WriteString(fmt.Sprintf("%v\n", err.Error()))
term.Exit(1) term.Exit(1)
} }
} }

View File

@ -15,6 +15,7 @@ import (
type Node interface { type Node interface {
TokenLiteral() string TokenLiteral() string
Tok() token.Token
String() string String() string
} }
@ -62,6 +63,7 @@ var _ Declaration = &FunctionDeclaration{}
func (fd *FunctionDeclaration) declarationNode() {} func (fd *FunctionDeclaration) declarationNode() {}
func (fd *FunctionDeclaration) TokenLiteral() string { return fd.Token.Literal } func (fd *FunctionDeclaration) TokenLiteral() string { return fd.Token.Literal }
func (fd *FunctionDeclaration) Tok() token.Token { return fd.Token }
func (fd *FunctionDeclaration) String() string { func (fd *FunctionDeclaration) String() string {
return fmt.Sprintf("fn %v(): %v = %v;", fd.Name, fd.ReturnType.Name(), fd.Body.String()) return fmt.Sprintf("fn %v(): %v = %v;", fd.Name, fd.ReturnType.Name(), fd.Body.String())
} }
@ -78,6 +80,7 @@ func (ie *IntegerExpression) Type() types.Type {
return types.I64 return types.I64
} }
func (ie *IntegerExpression) TokenLiteral() string { return ie.Token.Literal } func (ie *IntegerExpression) TokenLiteral() string { return ie.Token.Literal }
func (ie *IntegerExpression) Tok() token.Token { return ie.Token }
func (ie *IntegerExpression) String() string { return ie.Token.Literal } func (ie *IntegerExpression) String() string { return ie.Token.Literal }
type BooleanExpression struct { type BooleanExpression struct {
@ -92,6 +95,7 @@ func (ie *BooleanExpression) Type() types.Type {
return types.Bool return types.Bool
} }
func (be *BooleanExpression) TokenLiteral() string { return be.Token.Literal } func (be *BooleanExpression) TokenLiteral() string { return be.Token.Literal }
func (be *BooleanExpression) Tok() token.Token { return be.Token }
func (be *BooleanExpression) String() string { return be.Token.Literal } func (be *BooleanExpression) String() string { return be.Token.Literal }
type BinaryExpression struct { type BinaryExpression struct {
@ -108,6 +112,7 @@ func (be *BinaryExpression) Type() types.Type {
return be.ResultType return be.ResultType
} }
func (be *BinaryExpression) TokenLiteral() string { return be.Token.Literal } func (be *BinaryExpression) TokenLiteral() string { return be.Token.Literal }
func (be *BinaryExpression) Tok() token.Token { return be.Token }
func (be *BinaryExpression) String() string { func (be *BinaryExpression) String() string {
return fmt.Sprintf("(%s %s %s :> %s)", be.Lhs, be.Operator.SymbolString(), be.Rhs, be.ResultType.Name()) return fmt.Sprintf("(%s %s %s :> %s)", be.Lhs, be.Operator.SymbolString(), be.Rhs, be.ResultType.Name())
} }
@ -119,11 +124,14 @@ type BlockExpression struct {
ReturnType types.Type ReturnType types.Type
} }
var _ Expression = &BlockExpression{}
func (be *BlockExpression) expressionNode() {} func (be *BlockExpression) expressionNode() {}
func (be *BlockExpression) Type() types.Type { func (be *BlockExpression) Type() types.Type {
return be.ReturnType return be.ReturnType
} }
func (be *BlockExpression) TokenLiteral() string { return be.Token.Literal } func (be *BlockExpression) TokenLiteral() string { return be.Token.Literal }
func (be *BlockExpression) Tok() token.Token { return be.Token }
func (be *BlockExpression) String() string { func (be *BlockExpression) String() string {
var builder strings.Builder var builder strings.Builder
@ -150,11 +158,14 @@ type IfExpression struct {
ReturnType types.Type ReturnType types.Type
} }
var _ Expression = &IfExpression{}
func (ie *IfExpression) expressionNode() {} func (ie *IfExpression) expressionNode() {}
func (ie *IfExpression) Type() types.Type { func (ie *IfExpression) Type() types.Type {
return ie.ReturnType return ie.ReturnType
} }
func (ie *IfExpression) TokenLiteral() string { return ie.Token.Literal } func (ie *IfExpression) TokenLiteral() string { return ie.Token.Literal }
func (ie *IfExpression) Tok() token.Token { return ie.Token }
func (ie *IfExpression) String() string { func (ie *IfExpression) String() string {
var builder strings.Builder var builder strings.Builder
@ -177,13 +188,16 @@ type VariableDeclaration struct {
Identifier string Identifier string
} }
var _ Expression = &VariableDeclaration{}
func (vd *VariableDeclaration) expressionNode() {} func (vd *VariableDeclaration) expressionNode() {}
func (vd *VariableDeclaration) Type() types.Type { func (vd *VariableDeclaration) Type() types.Type {
return vd.VariableType return types.Unit
} }
func (vd *VariableDeclaration) TokenLiteral() string { return vd.Token.Literal } func (vd *VariableDeclaration) TokenLiteral() string { return vd.Token.Literal }
func (vd *VariableDeclaration) Tok() token.Token { return vd.Token }
func (vd *VariableDeclaration) String() string { func (vd *VariableDeclaration) String() string {
return fmt.Sprintf("%s : %v = %s", vd.Identifier, vd.Type().Name(), vd.InitializingExpression) return fmt.Sprintf("%s : %v = %s", vd.Identifier, vd.VariableType.Name(), vd.InitializingExpression)
} }
type VariableReference struct { type VariableReference struct {
@ -192,14 +206,17 @@ type VariableReference struct {
VariableType types.Type VariableType types.Type
} }
var _ Expression = &VariableReference{}
func (vr *VariableReference) expressionNode() {} func (vr *VariableReference) expressionNode() {}
func (vr *VariableReference) Type() types.Type { func (vr *VariableReference) Type() types.Type {
return vr.VariableType return vr.VariableType
} }
func (vr *VariableReference) TokenLiteral() string { return vr.Token.Literal } func (vr *VariableReference) TokenLiteral() string { return vr.Token.Literal }
func (vr *VariableReference) Tok() token.Token { return vr.Token }
func (vr *VariableReference) String() string { func (vr *VariableReference) String() string {
return fmt.Sprintf("%s", vr.Identifier) return fmt.Sprintf("(%s :> %s)", vr.Identifier, vr.Type().Name())
} }
type AssignmentExpression struct { type AssignmentExpression struct {
@ -208,11 +225,14 @@ type AssignmentExpression struct {
Rhs Expression Rhs Expression
} }
var _ Expression = &AssignmentExpression{}
func (ae *AssignmentExpression) expressionNode() {} func (ae *AssignmentExpression) expressionNode() {}
func (ae *AssignmentExpression) Type() types.Type { func (ae *AssignmentExpression) Type() types.Type {
return types.Unit return types.Unit
} }
func (ae *AssignmentExpression) TokenLiteral() string { return ae.Token.Literal } func (ae *AssignmentExpression) TokenLiteral() string { return ae.Token.Literal }
func (ae *AssignmentExpression) Tok() token.Token { return ae.Token }
func (ae *AssignmentExpression) String() string { func (ae *AssignmentExpression) String() string {
return fmt.Sprintf("%s = %s", ae.Lhs.String(), ae.Rhs.String()) return fmt.Sprintf("%s = %s", ae.Lhs.String(), ae.Rhs.String())
} }

View File

@ -1,5 +1,5 @@
fn main() = { fn main() = {
hi: i64 = 4; hi := 4;
if hi == 2 in hi = 3 if hi == 2 in hi = 3
else hi = 2; else hi = 2;

View File

@ -121,6 +121,10 @@ func emitExpression(expr tast.Expression) (Operand, []Instruction) {
} }
instructions = append(instructions, Label(endOfIfLabel)) instructions = append(instructions, Label(endOfIfLabel))
return dst, instructions return dst, instructions
case *tast.AssignmentExpression:
case *tast.VariableDeclaration:
case *tast.VariableReference:
default:
panic(fmt.Sprintf("unexpected tast.Expression: %#v", expr))
} }
panic("unhandled tast.Expression case in ir emitter")
} }

View File

@ -10,8 +10,11 @@ import (
"robaertschi.xyz/robaertschi/tt/types" "robaertschi.xyz/robaertschi/tt/types"
) )
type Variables map[string]types.Type
type Checker struct { type Checker struct {
foundMain bool foundMain bool
functionVariables map[string]Variables
} }
func New() *Checker { func New() *Checker {
@ -52,7 +55,7 @@ func (c *Checker) CheckProgram(program *ast.Program) (*tast.Program, error) {
func (c *Checker) checkDeclaration(decl tast.Declaration) error { func (c *Checker) checkDeclaration(decl tast.Declaration) error {
switch decl := decl.(type) { switch decl := decl.(type) {
case *tast.FunctionDeclaration: case *tast.FunctionDeclaration:
err := c.checkExpression(decl.Body) err := c.checkExpression(c.functionVariables[decl.Name], decl.Body)
if err != nil { if err != nil {
return err return err
@ -67,15 +70,15 @@ func (c *Checker) checkDeclaration(decl tast.Declaration) error {
return errors.New("unhandled declaration in type checker") return errors.New("unhandled declaration in type checker")
} }
func (c *Checker) checkExpression(expr tast.Expression) error { func (c *Checker) checkExpression(vars Variables, expr tast.Expression) error {
switch expr := expr.(type) { switch expr := expr.(type) {
case *tast.IntegerExpression: case *tast.IntegerExpression:
return nil return nil
case *tast.BooleanExpression: case *tast.BooleanExpression:
return nil return nil
case *tast.BinaryExpression: case *tast.BinaryExpression:
lhsErr := c.checkExpression(expr.Lhs) lhsErr := c.checkExpression(vars, expr.Lhs)
rhsErr := c.checkExpression(expr.Rhs) rhsErr := c.checkExpression(vars, expr.Rhs)
var operandErr error var operandErr error
if lhsErr == nil && rhsErr == nil { if lhsErr == nil && rhsErr == nil {
if !expr.Lhs.Type().IsSameType(expr.Rhs.Type()) { if !expr.Lhs.Type().IsSameType(expr.Rhs.Type()) {
@ -90,32 +93,62 @@ func (c *Checker) checkExpression(expr tast.Expression) error {
errs := []error{} errs := []error{}
for _, expr := range expr.Expressions { for _, expr := range expr.Expressions {
errs = append(errs, c.checkExpression(expr)) errs = append(errs, c.checkExpression(vars, expr))
} }
if expr.ReturnExpression != nil { if expr.ReturnExpression != nil {
errs = append(errs, c.checkExpression(expr.ReturnExpression)) errs = append(errs, c.checkExpression(vars, expr.ReturnExpression))
} }
return errors.Join(errs...) return errors.Join(errs...)
case *tast.IfExpression: case *tast.IfExpression:
condErr := c.checkExpression(expr.Condition) condErr := c.checkExpression(vars, expr.Condition)
if condErr == nil { if condErr == nil {
if !expr.Condition.Type().IsSameType(types.Bool) { if !expr.Condition.Type().IsSameType(types.Bool) {
condErr = c.error(expr.Token, "the condition in the if should be a boolean, but got %q", expr.Condition.Type().Name()) condErr = c.error(expr.Token, "the condition in the if should be a boolean, but got %q", expr.Condition.Type().Name())
} }
} }
thenErr := c.checkExpression(expr.Then) thenErr := c.checkExpression(vars, expr.Then)
if expr.Else == nil { if expr.Else == nil {
return errors.Join(condErr, thenErr) return errors.Join(condErr, thenErr)
} }
elseErr := c.checkExpression(expr.Else) elseErr := c.checkExpression(vars, expr.Else)
if thenErr == nil && elseErr == nil { if thenErr == nil && elseErr == nil {
if !expr.Then.Type().IsSameType(expr.Else.Type()) { if !expr.Then.Type().IsSameType(expr.Else.Type()) {
thenErr = c.error(expr.Token, "the then branch of type %q does not match with the else branch of type %q", expr.Then.Type().Name(), expr.Else.Type().Name()) thenErr = c.error(expr.Token, "the then branch of type %q does not match with the else branch of type %q", expr.Then.Type().Name(), expr.Else.Type().Name())
} }
} }
return errors.Join(condErr, thenErr, elseErr) return errors.Join(condErr, thenErr, elseErr)
case *tast.AssignmentExpression:
varRef, ok := expr.Lhs.(*tast.VariableReference)
if !ok {
return c.error(expr.Token, "not a valid assignment target")
}
if !expr.Lhs.Type().IsSameType(expr.Lhs.Type()) {
return c.error(
expr.Rhs.Tok(),
"the assignment rhs has the wrong type, variable %q has type %q but got %q",
varRef.Identifier,
varRef.Type().Name(),
expr.Rhs.String(),
)
}
return nil
case *tast.VariableDeclaration:
if !expr.VariableType.IsSameType(expr.InitializingExpression.Type()) {
return c.error(expr.InitializingExpression.Tok(),
"initializing expression for variable %q has wrong type, expected %q but got %q",
expr.Identifier,
expr.VariableType.Name(),
expr.InitializingExpression.Type().Name(),
)
}
return nil
case *tast.VariableReference:
return nil
default:
panic(fmt.Sprintf("unexpected tast.Expression: %#v", expr))
} }
return fmt.Errorf("unhandled expression %T in type checker", expr) return fmt.Errorf("unhandled expression %T in type checker", expr)
} }

View File

@ -10,6 +10,7 @@ import (
) )
func (c *Checker) inferTypes(program *ast.Program) (*tast.Program, error) { func (c *Checker) inferTypes(program *ast.Program) (*tast.Program, error) {
c.functionVariables = make(map[string]Variables)
decls := []tast.Declaration{} decls := []tast.Declaration{}
errs := []error{} errs := []error{}
@ -28,7 +29,9 @@ func (c *Checker) inferTypes(program *ast.Program) (*tast.Program, error) {
func (c *Checker) inferDeclaration(decl ast.Declaration) (tast.Declaration, error) { func (c *Checker) inferDeclaration(decl ast.Declaration) (tast.Declaration, error) {
switch decl := decl.(type) { switch decl := decl.(type) {
case *ast.FunctionDeclaration: case *ast.FunctionDeclaration:
body, err := c.inferExpression(decl.Body) vars := make(Variables)
body, err := c.inferExpression(vars, decl.Body)
c.functionVariables[decl.Name] = vars
if err != nil { if err != nil {
return nil, err return nil, err
@ -39,7 +42,7 @@ func (c *Checker) inferDeclaration(decl ast.Declaration) (tast.Declaration, erro
return nil, errors.New("unhandled declaration in type inferer") return nil, errors.New("unhandled declaration in type inferer")
} }
func (c *Checker) inferExpression(expr ast.Expression) (tast.Expression, error) { func (c *Checker) inferExpression(vars Variables, expr ast.Expression) (tast.Expression, error) {
switch expr := expr.(type) { switch expr := expr.(type) {
case *ast.IntegerExpression: case *ast.IntegerExpression:
return &tast.IntegerExpression{Token: expr.Token, Value: expr.Value}, nil return &tast.IntegerExpression{Token: expr.Token, Value: expr.Value}, nil
@ -48,8 +51,8 @@ func (c *Checker) inferExpression(expr ast.Expression) (tast.Expression, error)
case *ast.ErrorExpression: case *ast.ErrorExpression:
return nil, c.error(expr.InvalidToken, "invalid expression") return nil, c.error(expr.InvalidToken, "invalid expression")
case *ast.BinaryExpression: case *ast.BinaryExpression:
lhs, lhsErr := c.inferExpression(expr.Lhs) lhs, lhsErr := c.inferExpression(vars, expr.Lhs)
rhs, rhsErr := c.inferExpression(expr.Rhs) rhs, rhsErr := c.inferExpression(vars, expr.Rhs)
var resultType types.Type var resultType types.Type
if lhsErr == nil && rhsErr == nil { if lhsErr == nil && rhsErr == nil {
if expr.Operator.IsBooleanOperator() { if expr.Operator.IsBooleanOperator() {
@ -65,7 +68,7 @@ func (c *Checker) inferExpression(expr ast.Expression) (tast.Expression, error)
errs := []error{} errs := []error{}
for _, expr := range expr.Expressions { for _, expr := range expr.Expressions {
newExpr, err := c.inferExpression(expr) newExpr, err := c.inferExpression(vars, expr)
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
} else { } else {
@ -76,7 +79,7 @@ func (c *Checker) inferExpression(expr ast.Expression) (tast.Expression, error)
var returnExpr tast.Expression var returnExpr tast.Expression
var returnType types.Type var returnType types.Type
if expr.ReturnExpression != nil { if expr.ReturnExpression != nil {
expr, err := c.inferExpression(expr.ReturnExpression) expr, err := c.inferExpression(vars, expr.ReturnExpression)
returnExpr = expr returnExpr = expr
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
@ -95,15 +98,74 @@ func (c *Checker) inferExpression(expr ast.Expression) (tast.Expression, error)
}, errors.Join(errs...) }, errors.Join(errs...)
case *ast.IfExpression: case *ast.IfExpression:
cond, condErr := c.inferExpression(expr.Condition) cond, condErr := c.inferExpression(vars, expr.Condition)
then, thenErr := c.inferExpression(expr.Then) then, thenErr := c.inferExpression(vars, expr.Then)
if expr.Else != nil { if expr.Else != nil {
elseExpr, elseErr := c.inferExpression(expr.Else) elseExpr, elseErr := c.inferExpression(vars, expr.Else)
return &tast.IfExpression{Token: expr.Token, Condition: cond, Then: then, Else: elseExpr, ReturnType: then.Type()}, errors.Join(condErr, thenErr, elseErr) return &tast.IfExpression{Token: expr.Token, Condition: cond, Then: then, Else: elseExpr, ReturnType: then.Type()}, errors.Join(condErr, thenErr, elseErr)
} }
return &tast.IfExpression{Token: expr.Token, Condition: cond, Then: then, Else: nil, ReturnType: types.Unit}, errors.Join(condErr, thenErr) return &tast.IfExpression{Token: expr.Token, Condition: cond, Then: then, Else: nil, ReturnType: types.Unit}, errors.Join(condErr, thenErr)
case *ast.AssignmentExpression:
varRef, ok := expr.Lhs.(*ast.VariableReference)
if !ok {
return &tast.AssignmentExpression{}, c.error(expr.Token, "not a valid assignment target")
}
rhs, err := c.inferExpression(vars, expr.Rhs)
if err != nil {
return &tast.AssignmentExpression{}, err
}
varRefT, err := c.inferExpression(vars, varRef)
return &tast.AssignmentExpression{Lhs: varRefT, Rhs: rhs, Token: expr.Token}, err
case *ast.VariableDeclaration:
vd := &tast.VariableDeclaration{}
var t types.Type
var initializingExpr tast.Expression
if expr.Type != "" {
var ok bool
t, ok = types.From(expr.Type)
if !ok {
return vd, c.error(expr.Token, "could not find the type %q", expr.Type)
}
var err error
initializingExpr, err = c.inferExpression(vars, expr.InitializingExpression)
if err != nil {
return vd, err
}
} else {
var err error
initializingExpr, err = c.inferExpression(vars, expr.InitializingExpression)
if err != nil {
return vd, err
}
t = initializingExpr.Type()
}
vd.VariableType = t
vars[expr.Identifier] = t
vd.InitializingExpression = initializingExpr
vd.Token = expr.Token
vd.Identifier = expr.Identifier
return vd, nil
case *ast.VariableReference:
vr := &tast.VariableReference{Identifier: expr.Identifier, Token: expr.Token}
t, ok := vars[expr.Identifier]
if !ok {
return vr, c.error(expr.Token, "could not get type for variable %q", vr.Identifier)
}
vr.VariableType = t
return vr, nil
default:
panic(fmt.Sprintf("unexpected ast.Expression: %#v", expr))
} }
return nil, fmt.Errorf("unhandled expression in type inferer") return nil, fmt.Errorf("unhandled expression in type inferer")
} }

View File

@ -1,77 +1,149 @@
package typechecker package typechecker
import ( import (
"errors"
"fmt" "fmt"
"robaertschi.xyz/robaertschi/tt/ast" "robaertschi.xyz/robaertschi/tt/ast"
"robaertschi.xyz/robaertschi/tt/types" "robaertschi.xyz/robaertschi/tt/token"
) )
type Variable struct { type Var struct {
Name string Name string
Type types.Type FromCurrentScope bool
} }
type Scope struct { type Scope struct {
Variables map[string]Variable Variables map[string]Var
ParentScope *Scope UniqueId int64
} }
func (s *Scope) Get(name string) (Variable, bool) { func errorf(t token.Token, format string, args ...any) error {
return fmt.Errorf("%s:%d:%d %s", t.Loc.File, t.Loc.Line, t.Loc.Col, fmt.Sprintf(format, args...))
}
func copyScope(s *Scope) Scope {
newVars := make(map[string]Var)
for k, v := range s.Variables {
newVars[k] = Var{Name: v.Name, FromCurrentScope: false}
}
return Scope{Variables: newVars}
}
func (s *Scope) Get(name string) (Var, bool) {
v, ok := s.Variables[name] v, ok := s.Variables[name]
if ok { if ok {
return v, true return v, true
} }
if s.ParentScope != nil { return Var{}, false
return s.ParentScope.Get(name)
} }
return Variable{}, false func (s *Scope) Set(name string, uniqName string) {
} s.Variables[name] = Var{Name: uniqName, FromCurrentScope: true}
func (s *Scope) Set(name string, t types.Type) {
s.Variables[name] = Variable{Name: name, Type: t}
} }
func (s *Scope) Has(name string) bool { func (s *Scope) Has(name string) bool {
_, ok := s.Variables[name] _, ok := s.Variables[name]
if !ok && s.ParentScope != nil {
return s.ParentScope.Has(name)
}
return ok return ok
} }
func VarResolve(p *ast.Program) (Scope, error) { func (s *Scope) HasInCurrent(name string) bool {
s := Scope{Variables: make(map[string]Variable)} v, ok := s.Variables[name]
if !ok {
return false
}
return v.FromCurrentScope
}
func VarResolve(p *ast.Program) (map[string]Scope, error) {
functionToScope := make(map[string]Scope)
for _, d := range p.Declarations { for _, d := range p.Declarations {
switch d := d.(type) { switch d := d.(type) {
case *ast.FunctionDeclaration: case *ast.FunctionDeclaration:
s := Scope{Variables: make(map[string]Var)}
err := VarResolveExpr(&s, d.Body) err := VarResolveExpr(&s, d.Body)
functionToScope[d.Name] = s
if err != nil { if err != nil {
return s, err return functionToScope, err
} }
} }
} }
return s, nil return functionToScope, nil
} }
func VarResolveExpr(s *Scope, e ast.Expression) error { func VarResolveExpr(s *Scope, e ast.Expression) error {
switch e := e.(type) { switch e := e.(type) {
case *ast.ErrorExpression: case *ast.ErrorExpression:
// NOTE: The Checker will take care of this
return nil
case *ast.AssignmentExpression: case *ast.AssignmentExpression:
err := VarResolveExpr(s, e.Lhs)
if err != nil {
return err
}
err = VarResolveExpr(s, e.Rhs)
if err != nil {
return err
}
case *ast.BinaryExpression: case *ast.BinaryExpression:
err := VarResolveExpr(s, e.Lhs)
if err != nil {
return err
}
err = VarResolveExpr(s, e.Rhs)
if err != nil {
return err
}
case *ast.BlockExpression: case *ast.BlockExpression:
case *ast.BooleanExpression: newS := copyScope(s)
errs := []error{}
for _, e := range e.Expressions {
errs = append(errs, VarResolveExpr(&newS, e))
}
errs = append(errs, VarResolveExpr(&newS, e.ReturnExpression))
return errors.Join(errs...)
case *ast.IfExpression: case *ast.IfExpression:
case *ast.IntegerExpression: err := VarResolveExpr(s, e.Condition)
if err != nil {
return err
}
err = VarResolveExpr(s, e.Then)
if err != nil {
return err
}
if e.Else != nil {
err = VarResolveExpr(s, e.Else)
if err != nil {
return err
}
}
case *ast.VariableDeclaration: case *ast.VariableDeclaration:
if s.HasInCurrent(e.Identifier) {
return errorf(e.Token, "variable %q redifinded", e.Identifier)
}
uniqName := fmt.Sprintf("%s.%d", e.Identifier, s.UniqueId)
s.UniqueId += 1
s.Set(e.Identifier, uniqName)
case *ast.VariableReference: case *ast.VariableReference:
v, ok := s.Get(e.Identifier)
if !ok {
return errorf(e.Token, "variable %q is not declared", e.Identifier)
}
e.Identifier = v.Name
case *ast.BooleanExpression:
case *ast.IntegerExpression:
default: default:
panic(fmt.Sprintf("unexpected ast.Expression: %#v", e)) panic(fmt.Sprintf("unexpected ast.Expression: %#v", e))
} }

View File

@ -45,6 +45,15 @@ func (ti *TypeId) Name() string {
return ti.name return ti.name
} }
var types map[string]Type = make(map[string]Type)
func New(id int64, name string) Type { func New(id int64, name string) Type {
return &TypeId{id: id, name: name} typeId := &TypeId{id: id, name: name}
types[name] = typeId
return typeId
}
func From(name string) (Type, bool) {
t, ok := types[name]
return t, ok
} }