mirror of
https://github.com/RoBaertschi/tt.git
synced 2025-04-15 21:43:30 +00:00
234 lines
6.6 KiB
Go
234 lines
6.6 KiB
Go
package typechecker
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
|
|
"robaertschi.xyz/robaertschi/tt/ast"
|
|
"robaertschi.xyz/robaertschi/tt/tast"
|
|
"robaertschi.xyz/robaertschi/tt/types"
|
|
)
|
|
|
|
func (c *Checker) inferTypes(program *ast.Program) (*tast.Program, error) {
|
|
c.functionVariables = make(map[string]Variables)
|
|
decls := []tast.Declaration{}
|
|
errs := []error{}
|
|
vars := make(Variables)
|
|
|
|
funcToParams := make(map[string][]tast.Parameter)
|
|
|
|
for _, decl := range program.Declarations {
|
|
switch decl := decl.(type) {
|
|
case *ast.FunctionDeclaration:
|
|
parameters := []tast.Parameter{}
|
|
for _, param := range decl.Parameters {
|
|
t, ok := types.From(param.Type)
|
|
if !ok {
|
|
return nil, c.error(decl.Token, "could not find the type %q for argument %q", param.Type, param.Name)
|
|
}
|
|
vars[param.Name] = t
|
|
parameters = append(parameters, tast.Parameter{Name: param.Name, Type: t})
|
|
}
|
|
|
|
t, ok := types.From(decl.ReturnType)
|
|
if !ok {
|
|
return nil, c.error(decl.Token, "invalid type %q", decl.ReturnType)
|
|
}
|
|
|
|
parameterTypes := []types.Type{}
|
|
|
|
for _, param := range parameters {
|
|
parameterTypes = append(parameterTypes, param.Type)
|
|
}
|
|
|
|
vars[decl.Name] = &types.FunctionType{ReturnType: t, Parameters: parameterTypes}
|
|
funcToParams[decl.Name] = parameters
|
|
}
|
|
}
|
|
|
|
for _, decl := range program.Declarations {
|
|
decl, err := c.inferDeclaration(funcToParams, copyVars(vars), decl)
|
|
if err == nil {
|
|
decls = append(decls, decl)
|
|
} else {
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
|
|
return &tast.Program{Declarations: decls}, errors.Join(errs...)
|
|
}
|
|
|
|
func (c *Checker) inferDeclaration(funcToParams map[string][]tast.Parameter, vars Variables, decl ast.Declaration) (tast.Declaration, error) {
|
|
switch decl := decl.(type) {
|
|
case *ast.FunctionDeclaration:
|
|
body, err := c.inferExpression(vars, decl.Body)
|
|
c.functionVariables[decl.Name] = vars
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &tast.FunctionDeclaration{Token: decl.Token, Parameters: funcToParams[decl.Name], Body: body, ReturnType: vars[decl.Name], Name: decl.Name}, nil
|
|
}
|
|
return nil, errors.New("unhandled declaration in type inferer")
|
|
}
|
|
|
|
func (c *Checker) inferExpression(vars Variables, expr ast.Expression) (tast.Expression, error) {
|
|
switch expr := expr.(type) {
|
|
case *ast.IntegerExpression:
|
|
return &tast.IntegerExpression{Token: expr.Token, Value: expr.Value}, nil
|
|
case *ast.BooleanExpression:
|
|
return &tast.BooleanExpression{Token: expr.Token, Value: expr.Value}, nil
|
|
case *ast.ErrorExpression:
|
|
return nil, c.error(expr.InvalidToken, "invalid expression")
|
|
case *ast.BinaryExpression:
|
|
lhs, lhsErr := c.inferExpression(vars, expr.Lhs)
|
|
rhs, rhsErr := c.inferExpression(vars, expr.Rhs)
|
|
var resultType types.Type
|
|
if lhsErr == nil && rhsErr == nil {
|
|
if expr.Operator.IsBooleanOperator() {
|
|
resultType = types.Bool
|
|
} else {
|
|
resultType = lhs.Type()
|
|
}
|
|
}
|
|
|
|
return &tast.BinaryExpression{Lhs: lhs, Rhs: rhs, Operator: expr.Operator, Token: expr.Token, ResultType: resultType}, errors.Join(lhsErr, rhsErr)
|
|
case *ast.BlockExpression:
|
|
expressions := []tast.Expression{}
|
|
errs := []error{}
|
|
|
|
for _, expr := range expr.Expressions {
|
|
newExpr, err := c.inferExpression(vars, expr)
|
|
if err != nil {
|
|
errs = append(errs, err)
|
|
} else {
|
|
expressions = append(expressions, newExpr)
|
|
}
|
|
}
|
|
|
|
var returnExpr tast.Expression
|
|
var returnType types.Type
|
|
if expr.ReturnExpression != nil {
|
|
expr, err := c.inferExpression(vars, expr.ReturnExpression)
|
|
returnExpr = expr
|
|
if err != nil {
|
|
errs = append(errs, err)
|
|
} else {
|
|
returnType = returnExpr.Type()
|
|
}
|
|
} else {
|
|
returnType = types.Unit
|
|
}
|
|
|
|
return &tast.BlockExpression{
|
|
Token: expr.Token,
|
|
Expressions: expressions,
|
|
ReturnType: returnType,
|
|
ReturnExpression: returnExpr,
|
|
}, errors.Join(errs...)
|
|
|
|
case *ast.IfExpression:
|
|
cond, condErr := c.inferExpression(vars, expr.Condition)
|
|
then, thenErr := c.inferExpression(vars, expr.Then)
|
|
|
|
if expr.Else != nil {
|
|
elseExpr, elseErr := c.inferExpression(vars, expr.Else)
|
|
|
|
return &tast.IfExpression{Token: expr.Token, Condition: cond, Then: then, Else: elseExpr, ReturnType: then.Type()}, errors.Join(condErr, thenErr, elseErr)
|
|
}
|
|
return &tast.IfExpression{Token: expr.Token, Condition: cond, Then: then, Else: nil, ReturnType: types.Unit}, errors.Join(condErr, thenErr)
|
|
case *ast.AssignmentExpression:
|
|
varRef, ok := expr.Lhs.(*ast.VariableReference)
|
|
if !ok {
|
|
return &tast.AssignmentExpression{}, c.error(expr.Token, "not a valid assignment target")
|
|
}
|
|
|
|
rhs, err := c.inferExpression(vars, expr.Rhs)
|
|
if err != nil {
|
|
return &tast.AssignmentExpression{}, err
|
|
}
|
|
|
|
varRefT, err := c.inferExpression(vars, varRef)
|
|
return &tast.AssignmentExpression{Lhs: varRefT, Rhs: rhs, Token: expr.Token}, err
|
|
case *ast.VariableDeclaration:
|
|
vd := &tast.VariableDeclaration{}
|
|
var t types.Type
|
|
var initializingExpr tast.Expression
|
|
|
|
if expr.Type != "" {
|
|
var ok bool
|
|
t, ok = types.From(expr.Type)
|
|
if !ok {
|
|
return vd, c.error(expr.Token, "could not find the type %q", expr.Type)
|
|
}
|
|
var err error
|
|
initializingExpr, err = c.inferExpression(vars, expr.InitializingExpression)
|
|
if err != nil {
|
|
return vd, err
|
|
}
|
|
} else {
|
|
var err error
|
|
initializingExpr, err = c.inferExpression(vars, expr.InitializingExpression)
|
|
if err != nil {
|
|
return vd, err
|
|
}
|
|
|
|
t = initializingExpr.Type()
|
|
}
|
|
|
|
vd.VariableType = t
|
|
vars[expr.Identifier] = t
|
|
|
|
vd.InitializingExpression = initializingExpr
|
|
vd.Token = expr.Token
|
|
vd.Identifier = expr.Identifier
|
|
return vd, nil
|
|
case *ast.VariableReference:
|
|
vr := &tast.VariableReference{Identifier: expr.Identifier, Token: expr.Token}
|
|
|
|
t, ok := vars[expr.Identifier]
|
|
if !ok {
|
|
return vr, c.error(expr.Token, "could not get type for variable %q", vr.Identifier)
|
|
}
|
|
|
|
vr.VariableType = t
|
|
|
|
return vr, nil
|
|
case *ast.FunctionCall:
|
|
fc := &tast.FunctionCall{Identifier: expr.Identifier, Token: expr.Token}
|
|
|
|
t, ok := vars[expr.Identifier]
|
|
if !ok {
|
|
return fc, c.error(expr.Token, "could not get type for function %q", fc.Identifier)
|
|
}
|
|
|
|
funcType, ok := t.(*types.FunctionType)
|
|
if !ok {
|
|
return fc, c.error(expr.Token, "tried to call non function variable %q with type %q", expr.Identifier, t.Name())
|
|
}
|
|
|
|
fc.ReturnType = funcType.ReturnType
|
|
|
|
args := []tast.Expression{}
|
|
errs := []error{}
|
|
|
|
for _, arg := range expr.Arguments {
|
|
inferredArg, err := c.inferExpression(vars, arg)
|
|
errs = append(errs, err)
|
|
|
|
if err == nil {
|
|
args = append(args, inferredArg)
|
|
}
|
|
}
|
|
|
|
fc.Arguments = args
|
|
|
|
return fc, errors.Join(errs...)
|
|
|
|
default:
|
|
panic(fmt.Sprintf("unexpected ast.Expression: %#v", expr))
|
|
}
|
|
return nil, fmt.Errorf("unhandled expression in type inferer")
|
|
}
|