added function calls util typechecker WIP

still have to add support in ttir and both backends
This commit is contained in:
Robin Bärtschi 2025-03-11 12:07:46 +01:00
parent 22aa1ac3ed
commit bcfa9fbde8
8 changed files with 114 additions and 25 deletions

View File

@ -64,9 +64,10 @@ type FunctionDeclaration struct {
Body Expression
Name string
Parameters []Parameter
ReturnType Type
}
func ArgsToString(args []Parameter) string {
func ParamsToString(args []Parameter) string {
var b strings.Builder
for _, arg := range args {
@ -80,7 +81,7 @@ func (fd *FunctionDeclaration) declarationNode() {}
func (fd *FunctionDeclaration) TokenLiteral() string { return fd.Token.Literal }
func (fd *FunctionDeclaration) Tok() token.Token { return fd.Token }
func (fd *FunctionDeclaration) String() string {
return fmt.Sprintf("fn %v(%v) = %v;", fd.Name, ArgsToString(fd.Parameters), fd.Body.String())
return fmt.Sprintf("fn %v(%v): %v = %v;", fd.Name, ParamsToString(fd.Parameters), fd.ReturnType, fd.Body.String())
}
// Represents a Expression that we failed to parse

View File

@ -236,6 +236,14 @@ func (p *Parser) parseDeclaration() ast.Declaration {
if ok, _ := p.expectPeek(token.CloseParen); !ok {
return nil
}
if ok, _ := p.expectPeek(token.Colon); !ok {
return nil
}
p.nextToken()
t, ok := p.parseType()
if !ok {
return nil
}
if ok, _ := p.expectPeek(token.Equal); !ok {
return nil
}
@ -251,6 +259,7 @@ func (p *Parser) parseDeclaration() ast.Declaration {
Name: name,
Body: expr,
Parameters: params,
ReturnType: t,
}
}

View File

@ -52,12 +52,12 @@ func (p *Program) String() string {
return builder.String()
}
type Argument struct {
type Parameter struct {
Name string
Type types.Type
}
func ArgsToString(args []Argument) string {
func ArgsToString(args []Parameter) string {
var b strings.Builder
for _, arg := range args {
@ -71,7 +71,7 @@ type FunctionDeclaration struct {
Token token.Token // The token.FN
Body Expression
Name string
Args []Argument
Parameters []Parameter
ReturnType types.Type
}
@ -81,7 +81,7 @@ func (fd *FunctionDeclaration) declarationNode() {}
func (fd *FunctionDeclaration) TokenLiteral() string { return fd.Token.Literal }
func (fd *FunctionDeclaration) Tok() token.Token { return fd.Token }
func (fd *FunctionDeclaration) String() string {
return fmt.Sprintf("fn %v(%v): %v = %v;", fd.Name, ArgsToString(fd.Args), fd.ReturnType.Name(), fd.Body.String())
return fmt.Sprintf("fn %v(%v): %v = %v;", fd.Name, ArgsToString(fd.Parameters), fd.ReturnType.Name(), fd.Body.String())
}
type IntegerExpression struct {

View File

@ -1,4 +1,4 @@
fn main() = {
fn main(): i64 = {
i := 5;
@ -11,7 +11,7 @@ fn main() = {
test2(3)
};
fn test2(hello: i64) = {
fn test2(hello: i64): i64 = {
hello // Comment test
};

View File

@ -47,7 +47,7 @@ func emitFunction(function *tast.FunctionDeclaration) *Function {
arguments := []string{}
for _, arg := range function.Args {
for _, arg := range function.Parameters {
arguments = append(arguments, arg.Name)
}

View File

@ -12,6 +12,16 @@ import (
type Variables map[string]types.Type
func copyVars(vars Variables) Variables {
newVars := make(Variables)
for k, v := range vars {
newVars[k] = v
}
return newVars
}
type Checker struct {
foundMain bool
functionVariables map[string]Variables
@ -152,6 +162,22 @@ func (c *Checker) checkExpression(vars Variables, expr tast.Expression) error {
return nil
case *tast.VariableReference:
return nil
case *tast.FunctionCall:
functionType := vars[expr.Identifier].(*types.FunctionType)
if len(expr.Arguments) != len(functionType.Parameters) {
return c.error(expr.Token, "invalid amount of arguments for function %q, expected %d but got %d", expr.Identifier, len(functionType.Parameters), len(expr.Arguments))
}
errs := []error{}
for i, param := range functionType.Parameters {
e := expr.Arguments[i]
if !e.Type().IsSameType(param) {
errs = append(errs, c.error(e.Tok(), "invalid type for parameter, expected %q but got %q", param.Name(), e.Type().Name()))
}
}
return errors.Join(errs...)
default:
panic(fmt.Sprintf("unexpected tast.Expression: %#v", expr))
}

View File

@ -13,9 +13,41 @@ func (c *Checker) inferTypes(program *ast.Program) (*tast.Program, error) {
c.functionVariables = make(map[string]Variables)
decls := []tast.Declaration{}
errs := []error{}
vars := make(Variables)
funcToParams := make(map[string][]tast.Parameter)
for _, decl := range program.Declarations {
decl, err := c.inferDeclaration(decl)
switch decl := decl.(type) {
case *ast.FunctionDeclaration:
parameters := []tast.Parameter{}
for _, param := range decl.Parameters {
t, ok := types.From(param.Type)
if !ok {
return nil, c.error(decl.Token, "could not find the type %q for argument %q", param.Type, param.Name)
}
vars[param.Name] = t
parameters = append(parameters, tast.Parameter{Name: param.Name, Type: t})
}
t, ok := types.From(decl.ReturnType)
if !ok {
return nil, c.error(decl.Token, "invalid type %q", decl.ReturnType)
}
parameterTypes := []types.Type{}
for _, param := range parameters {
parameterTypes = append(parameterTypes, param.Type)
}
vars[decl.Name] = &types.FunctionType{ReturnType: t, Parameters: parameterTypes}
funcToParams[decl.Name] = parameters
}
}
for _, decl := range program.Declarations {
decl, err := c.inferDeclaration(funcToParams, copyVars(vars), decl)
if err == nil {
decls = append(decls, decl)
} else {
@ -26,20 +58,9 @@ func (c *Checker) inferTypes(program *ast.Program) (*tast.Program, error) {
return &tast.Program{Declarations: decls}, errors.Join(errs...)
}
func (c *Checker) inferDeclaration(decl ast.Declaration) (tast.Declaration, error) {
func (c *Checker) inferDeclaration(funcToParams map[string][]tast.Parameter, vars Variables, decl ast.Declaration) (tast.Declaration, error) {
switch decl := decl.(type) {
case *ast.FunctionDeclaration:
vars := make(Variables)
parameters := []tast.Parameter{}
for _, param := range decl.Parameters {
t, ok := types.From(param.Type)
if !ok {
return nil, c.error(decl.Token, "could not find the type %q for argument %q", param.Type, param.Name)
}
vars[param.Name] = t
parameters = append(parameters, tast.Parameter{Name: param.Name, Type: t})
}
// vars[decl.Name] = &types.FunctionType{ReturnType: }
body, err := c.inferExpression(vars, decl.Body)
c.functionVariables[decl.Name] = vars
@ -47,7 +68,7 @@ func (c *Checker) inferDeclaration(decl ast.Declaration) (tast.Declaration, erro
return nil, err
}
return &tast.FunctionDeclaration{Token: decl.Token, Args: parameters, Body: body, ReturnType: body.Type(), Name: decl.Name}, nil
return &tast.FunctionDeclaration{Token: decl.Token, Parameters: funcToParams[decl.Name], Body: body, ReturnType: vars[decl.Name], Name: decl.Name}, nil
}
return nil, errors.New("unhandled declaration in type inferer")
}
@ -181,6 +202,30 @@ func (c *Checker) inferExpression(vars Variables, expr ast.Expression) (tast.Exp
if !ok {
return fc, c.error(expr.Token, "could not get type for function %q", fc.Identifier)
}
funcType, ok := t.(*types.FunctionType)
if !ok {
return fc, c.error(expr.Token, "tried to call non function variable %q with type %q", expr.Identifier, t.Name())
}
fc.ReturnType = funcType.ReturnType
args := []tast.Expression{}
errs := []error{}
for _, arg := range expr.Arguments {
inferredArg, err := c.inferExpression(vars, arg)
errs = append(errs, err)
if err == nil {
args = append(args, inferredArg)
}
}
fc.Arguments = args
return fc, errors.Join(errs...)
default:
panic(fmt.Sprintf("unexpected ast.Expression: %#v", expr))
}

View File

@ -69,6 +69,15 @@ func (s *Scope) SetUniq(name string) string {
func VarResolve(p *ast.Program) (map[string]Scope, error) {
functionToScope := make(map[string]Scope)
functions := Scope{Variables: make(map[string]Var)}
for _, d := range p.Declarations {
switch d := d.(type) {
case *ast.FunctionDeclaration:
functions.Set(d.Name, d.Name)
default:
}
}
for _, d := range p.Declarations {
switch d := d.(type) {
@ -78,8 +87,7 @@ func VarResolve(p *ast.Program) (map[string]Scope, error) {
return functionToScope, errorf(d.Token, "duplicate function name %q", d.Name)
}
s := Scope{Variables: make(map[string]Var)}
s.Set(d.Name, d.Name)
s := copyScope(&functions)
for i, param := range d.Parameters {
uniq := s.SetUniq(param.Name)
d.Parameters[i].Name = uniq