diff --git a/ast/ast.go b/ast/ast.go index b81826c..2845fc8 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -64,9 +64,10 @@ type FunctionDeclaration struct { Body Expression Name string Parameters []Parameter + ReturnType Type } -func ArgsToString(args []Parameter) string { +func ParamsToString(args []Parameter) string { var b strings.Builder for _, arg := range args { @@ -80,7 +81,7 @@ func (fd *FunctionDeclaration) declarationNode() {} func (fd *FunctionDeclaration) TokenLiteral() string { return fd.Token.Literal } func (fd *FunctionDeclaration) Tok() token.Token { return fd.Token } func (fd *FunctionDeclaration) String() string { - return fmt.Sprintf("fn %v(%v) = %v;", fd.Name, ArgsToString(fd.Parameters), fd.Body.String()) + return fmt.Sprintf("fn %v(%v): %v = %v;", fd.Name, ParamsToString(fd.Parameters), fd.ReturnType, fd.Body.String()) } // Represents a Expression that we failed to parse diff --git a/parser/parser.go b/parser/parser.go index 30b03d9..320a874 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -236,6 +236,14 @@ func (p *Parser) parseDeclaration() ast.Declaration { if ok, _ := p.expectPeek(token.CloseParen); !ok { return nil } + if ok, _ := p.expectPeek(token.Colon); !ok { + return nil + } + p.nextToken() + t, ok := p.parseType() + if !ok { + return nil + } if ok, _ := p.expectPeek(token.Equal); !ok { return nil } @@ -251,6 +259,7 @@ func (p *Parser) parseDeclaration() ast.Declaration { Name: name, Body: expr, Parameters: params, + ReturnType: t, } } diff --git a/tast/tast.go b/tast/tast.go index 1902f13..3d9ff6e 100644 --- a/tast/tast.go +++ b/tast/tast.go @@ -52,12 +52,12 @@ func (p *Program) String() string { return builder.String() } -type Argument struct { +type Parameter struct { Name string Type types.Type } -func ArgsToString(args []Argument) string { +func ArgsToString(args []Parameter) string { var b strings.Builder for _, arg := range args { @@ -71,7 +71,7 @@ type FunctionDeclaration struct { Token token.Token // The token.FN Body Expression Name string - Args []Argument + Parameters []Parameter ReturnType types.Type } @@ -81,7 +81,7 @@ func (fd *FunctionDeclaration) declarationNode() {} func (fd *FunctionDeclaration) TokenLiteral() string { return fd.Token.Literal } func (fd *FunctionDeclaration) Tok() token.Token { return fd.Token } func (fd *FunctionDeclaration) String() string { - return fmt.Sprintf("fn %v(%v): %v = %v;", fd.Name, ArgsToString(fd.Args), fd.ReturnType.Name(), fd.Body.String()) + return fmt.Sprintf("fn %v(%v): %v = %v;", fd.Name, ArgsToString(fd.Parameters), fd.ReturnType.Name(), fd.Body.String()) } type IntegerExpression struct { diff --git a/test.tt b/test.tt index b0bbf95..682d676 100644 --- a/test.tt +++ b/test.tt @@ -1,4 +1,4 @@ -fn main() = { +fn main(): i64 = { i := 5; @@ -11,7 +11,7 @@ fn main() = { test2(3) }; -fn test2(hello: i64) = { +fn test2(hello: i64): i64 = { hello // Comment test }; diff --git a/ttir/emit.go b/ttir/emit.go index f0b749f..aab4be3 100644 --- a/ttir/emit.go +++ b/ttir/emit.go @@ -47,7 +47,7 @@ func emitFunction(function *tast.FunctionDeclaration) *Function { arguments := []string{} - for _, arg := range function.Args { + for _, arg := range function.Parameters { arguments = append(arguments, arg.Name) } diff --git a/typechecker/check.go b/typechecker/check.go index 35f73d0..b660a7c 100644 --- a/typechecker/check.go +++ b/typechecker/check.go @@ -12,6 +12,16 @@ import ( type Variables map[string]types.Type +func copyVars(vars Variables) Variables { + newVars := make(Variables) + + for k, v := range vars { + newVars[k] = v + } + + return newVars +} + type Checker struct { foundMain bool functionVariables map[string]Variables @@ -152,6 +162,22 @@ func (c *Checker) checkExpression(vars Variables, expr tast.Expression) error { return nil case *tast.VariableReference: return nil + case *tast.FunctionCall: + functionType := vars[expr.Identifier].(*types.FunctionType) + if len(expr.Arguments) != len(functionType.Parameters) { + return c.error(expr.Token, "invalid amount of arguments for function %q, expected %d but got %d", expr.Identifier, len(functionType.Parameters), len(expr.Arguments)) + } + + errs := []error{} + + for i, param := range functionType.Parameters { + e := expr.Arguments[i] + if !e.Type().IsSameType(param) { + errs = append(errs, c.error(e.Tok(), "invalid type for parameter, expected %q but got %q", param.Name(), e.Type().Name())) + } + } + + return errors.Join(errs...) default: panic(fmt.Sprintf("unexpected tast.Expression: %#v", expr)) } diff --git a/typechecker/infer.go b/typechecker/infer.go index 0f2825f..328672f 100644 --- a/typechecker/infer.go +++ b/typechecker/infer.go @@ -13,9 +13,41 @@ func (c *Checker) inferTypes(program *ast.Program) (*tast.Program, error) { c.functionVariables = make(map[string]Variables) decls := []tast.Declaration{} errs := []error{} + vars := make(Variables) + + funcToParams := make(map[string][]tast.Parameter) for _, decl := range program.Declarations { - decl, err := c.inferDeclaration(decl) + switch decl := decl.(type) { + case *ast.FunctionDeclaration: + parameters := []tast.Parameter{} + for _, param := range decl.Parameters { + t, ok := types.From(param.Type) + if !ok { + return nil, c.error(decl.Token, "could not find the type %q for argument %q", param.Type, param.Name) + } + vars[param.Name] = t + parameters = append(parameters, tast.Parameter{Name: param.Name, Type: t}) + } + + t, ok := types.From(decl.ReturnType) + if !ok { + return nil, c.error(decl.Token, "invalid type %q", decl.ReturnType) + } + + parameterTypes := []types.Type{} + + for _, param := range parameters { + parameterTypes = append(parameterTypes, param.Type) + } + + vars[decl.Name] = &types.FunctionType{ReturnType: t, Parameters: parameterTypes} + funcToParams[decl.Name] = parameters + } + } + + for _, decl := range program.Declarations { + decl, err := c.inferDeclaration(funcToParams, copyVars(vars), decl) if err == nil { decls = append(decls, decl) } else { @@ -26,20 +58,9 @@ func (c *Checker) inferTypes(program *ast.Program) (*tast.Program, error) { return &tast.Program{Declarations: decls}, errors.Join(errs...) } -func (c *Checker) inferDeclaration(decl ast.Declaration) (tast.Declaration, error) { +func (c *Checker) inferDeclaration(funcToParams map[string][]tast.Parameter, vars Variables, decl ast.Declaration) (tast.Declaration, error) { switch decl := decl.(type) { case *ast.FunctionDeclaration: - vars := make(Variables) - parameters := []tast.Parameter{} - for _, param := range decl.Parameters { - t, ok := types.From(param.Type) - if !ok { - return nil, c.error(decl.Token, "could not find the type %q for argument %q", param.Type, param.Name) - } - vars[param.Name] = t - parameters = append(parameters, tast.Parameter{Name: param.Name, Type: t}) - } - // vars[decl.Name] = &types.FunctionType{ReturnType: } body, err := c.inferExpression(vars, decl.Body) c.functionVariables[decl.Name] = vars @@ -47,7 +68,7 @@ func (c *Checker) inferDeclaration(decl ast.Declaration) (tast.Declaration, erro return nil, err } - return &tast.FunctionDeclaration{Token: decl.Token, Args: parameters, Body: body, ReturnType: body.Type(), Name: decl.Name}, nil + return &tast.FunctionDeclaration{Token: decl.Token, Parameters: funcToParams[decl.Name], Body: body, ReturnType: vars[decl.Name], Name: decl.Name}, nil } return nil, errors.New("unhandled declaration in type inferer") } @@ -181,6 +202,30 @@ func (c *Checker) inferExpression(vars Variables, expr ast.Expression) (tast.Exp if !ok { return fc, c.error(expr.Token, "could not get type for function %q", fc.Identifier) } + + funcType, ok := t.(*types.FunctionType) + if !ok { + return fc, c.error(expr.Token, "tried to call non function variable %q with type %q", expr.Identifier, t.Name()) + } + + fc.ReturnType = funcType.ReturnType + + args := []tast.Expression{} + errs := []error{} + + for _, arg := range expr.Arguments { + inferredArg, err := c.inferExpression(vars, arg) + errs = append(errs, err) + + if err == nil { + args = append(args, inferredArg) + } + } + + fc.Arguments = args + + return fc, errors.Join(errs...) + default: panic(fmt.Sprintf("unexpected ast.Expression: %#v", expr)) } diff --git a/typechecker/variable_resolution.go b/typechecker/variable_resolution.go index 1730308..ccdc673 100644 --- a/typechecker/variable_resolution.go +++ b/typechecker/variable_resolution.go @@ -69,6 +69,15 @@ func (s *Scope) SetUniq(name string) string { func VarResolve(p *ast.Program) (map[string]Scope, error) { functionToScope := make(map[string]Scope) + functions := Scope{Variables: make(map[string]Var)} + + for _, d := range p.Declarations { + switch d := d.(type) { + case *ast.FunctionDeclaration: + functions.Set(d.Name, d.Name) + default: + } + } for _, d := range p.Declarations { switch d := d.(type) { @@ -78,8 +87,7 @@ func VarResolve(p *ast.Program) (map[string]Scope, error) { return functionToScope, errorf(d.Token, "duplicate function name %q", d.Name) } - s := Scope{Variables: make(map[string]Var)} - s.Set(d.Name, d.Name) + s := copyScope(&functions) for i, param := range d.Parameters { uniq := s.SetUniq(param.Name) d.Parameters[i].Name = uniq