Commit Diff


commit - 862cec691012ceb800f90f956da2b62c8811e4f6
commit + 11a12b204a216d8d8e25fdac1103c8beef3027e3
blob - 4b49dea9826b681b7046ab844cfe218f7c151aec
blob + a9c2fdf11d12d79210ad103cf28523fcf11a0dd5
--- src/compiler/mod.rs
+++ src/compiler/mod.rs
@@ -18,115 +18,632 @@
 pub mod opcode;
 
 use crate::error::OlangError;
-use crate::parser::ast::{BinOp, Expr, Program, Stmt, UnaryOp};
+use crate::parser::ast::{
+    BinOp, Expr, FnDecl, Program, Stmt, Type,
+    UnaryOp,
+};
 use crate::span::Span;
 use crate::vm::value::Value;
 use opcode::{Chunk, OpCode};
 
+pub struct CompileResult {
+    pub functions: Vec<Chunk>,
+    pub main_index: usize,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+enum ValType {
+    Int,
+    Float,
+    Str,
+    Bool,
+    Unknown,
+}
+
+impl ValType {
+    fn from_ast_type(ty: &Type) -> Self {
+        match ty {
+            Type::Int => ValType::Int,
+            Type::Float => ValType::Float,
+            Type::Str => ValType::Str,
+            Type::Bool => ValType::Bool,
+        }
+    }
+
+    fn name(&self) -> &str {
+        match self {
+            ValType::Int => "int",
+            ValType::Float => "float",
+            ValType::Str => "str",
+            ValType::Bool => "bool",
+            ValType::Unknown => "unknown",
+        }
+    }
+}
+
 struct Local {
     name: String,
     mutable: bool,
+    ty: ValType,
 }
 
+struct FnInfo {
+    name: String,
+    index: usize,
+}
+
 pub struct Compiler {
     chunk: Chunk,
     locals: Vec<Local>,
+    fn_table: Vec<FnInfo>,
 }
 
 impl Compiler {
     pub fn new() -> Self {
-        Self { chunk: Chunk::new(), locals: Vec::new() }
+        Self {
+            chunk: Chunk::new("", 0),
+            locals: Vec::new(),
+            fn_table: Vec::new(),
+        }
     }
 
-    fn resolve_local(&self, name: &str) -> Option<usize> {
-        self.locals.iter().rposition(|l| l.name == name)
+    fn resolve_local(
+        &self,
+        name: &str,
+    ) -> Option<usize> {
+        self.locals
+            .iter()
+            .rposition(|l| l.name == name)
     }
 
+    fn resolve_fn(
+        &self,
+        name: &str,
+    ) -> Option<usize> {
+        self.fn_table
+            .iter()
+            .find(|f| f.name == name)
+            .map(|f| f.index)
+    }
+
     pub fn compile(
-        mut self, program: &Program,
+        mut self,
+        program: &Program,
+    ) -> Result<CompileResult, OlangError> {
+        // First pass: register all functions
+        for (i, func) in
+            program.functions.iter().enumerate()
+        {
+            self.fn_table.push(FnInfo {
+                name: func.name.clone(),
+                index: i,
+            });
+        }
+
+        // Find main
+        let main_index = self
+            .resolve_fn("main")
+            .ok_or_else(|| {
+                OlangError::new(
+                    "no main() function found",
+                    Span::new(0, 0),
+                )
+            })?;
+
+        // Second pass: compile each function
+        let mut chunks = Vec::new();
+        for func in &program.functions {
+            let chunk =
+                self.compile_fn(func)?;
+            chunks.push(chunk);
+        }
+
+        Ok(CompileResult {
+            functions: chunks,
+            main_index,
+        })
+    }
+
+    fn compile_fn(
+        &mut self,
+        func: &FnDecl,
     ) -> Result<Chunk, OlangError> {
-        let main_fn = program.functions.iter()
-            .find(|f| f.name == "main")
-            .ok_or_else(|| OlangError::new(
-                "no main() function found", Span::new(0, 0),
-            ))?;
-        for stmt in &main_fn.body {
+        // Save and reset state
+        let prev_chunk = std::mem::replace(
+            &mut self.chunk,
+            Chunk::new(
+                &func.name,
+                func.params.len() as u8,
+            ),
+        );
+        let prev_locals =
+            std::mem::take(&mut self.locals);
+
+        // Parameters become the first locals
+        for param in &func.params {
+            self.locals.push(Local {
+                name: param.name.clone(),
+                mutable: false,
+                ty: ValType::from_ast_type(&param.ty),
+            });
+        }
+
+        // Compile body
+        let body = &func.body;
+        if body.is_empty() {
+            // Empty function: return nothing
+            self.chunk.emit_constant(
+                Value::Int(0),
+                0,
+            );
+            self.chunk.emit_op(OpCode::Return, 0);
+        } else {
+            // Compile all but last statement
+            for stmt in &body[..body.len() - 1] {
+                self.compile_stmt(stmt)?;
+            }
+
+            // Last statement: check for implicit
+            // return (ExprStmt without ;)
+            let last = &body[body.len() - 1];
+            self.compile_last_stmt(last)?;
+        }
+
+        // Restore state
+        let chunk = std::mem::replace(
+            &mut self.chunk,
+            prev_chunk,
+        );
+        self.locals = prev_locals;
+
+        Ok(chunk)
+    }
+
+    fn compile_last_stmt(
+        &mut self,
+        stmt: &Stmt,
+    ) -> Result<(), OlangError> {
+        match stmt {
+            // ExprStmt as last statement = implicit
+            // return
+            Stmt::ExprStmt { expr, .. } => {
+                self.compile_expr(expr)?;
+                self.chunk.emit_op(
+                    OpCode::Return,
+                    0,
+                );
+            }
+            // Any other statement: compile normally
+            // and emit a default return
+            _ => {
+                self.compile_stmt(stmt)?;
+                self.chunk.emit_constant(
+                    Value::Int(0),
+                    0,
+                );
+                self.chunk.emit_op(
+                    OpCode::Return,
+                    0,
+                );
+            }
+        }
+        Ok(())
+    }
+
+    fn compile_block(
+        &mut self,
+        block: &[Stmt],
+    ) -> Result<(), OlangError> {
+        let locals_before = self.locals.len();
+
+        for stmt in block {
             self.compile_stmt(stmt)?;
         }
-        self.chunk.emit_op(OpCode::Return, 0);
-        Ok(self.chunk)
+
+        let locals_after = self.locals.len();
+        for _ in locals_before..locals_after {
+            self.chunk.emit_op(OpCode::Pop, 0);
+            self.locals.pop();
+        }
+
+        Ok(())
     }
 
-    fn compile_stmt(&mut self, stmt: &Stmt) -> Result<(), OlangError> {
+    fn compile_stmt(
+        &mut self,
+        stmt: &Stmt,
+    ) -> Result<(), OlangError> {
         match stmt {
-            Stmt::Let { name, mutable, value, .. } => {
+            Stmt::Let {
+                name,
+                mutable,
+                ty,
+                value,
+                span,
+            } => {
+                let inferred =
+                    self.infer_type(value);
+
+                // Check type annotation matches
+                if let Some(ann) = ty {
+                    let expected =
+                        ValType::from_ast_type(ann);
+                    if inferred != ValType::Unknown
+                        && inferred != expected
+                    {
+                        return Err(OlangError::new(
+                            format!(
+                                "type mismatch: \
+                                 expected {}, \
+                                 found {}",
+                                expected.name(),
+                                inferred.name(),
+                            ),
+                            Self::expr_span(value),
+                        ));
+                    }
+                }
+
+                let local_ty = if let Some(ann) = ty {
+                    ValType::from_ast_type(ann)
+                } else {
+                    inferred
+                };
+
                 self.compile_expr(value)?;
                 self.locals.push(Local {
-                    name: name.clone(), mutable: *mutable,
+                    name: name.clone(),
+                    mutable: *mutable,
+                    ty: local_ty,
                 });
             }
-            Stmt::Assign { name, value, span } => {
-                let slot = self.resolve_local(name).ok_or_else(|| {
-                    OlangError::new(
-                        format!("undefined variable '{name}'"), *span,
-                    )
-                })?;
+            Stmt::Assign {
+                name,
+                value,
+                span,
+            } => {
+                let slot = self
+                    .resolve_local(name)
+                    .ok_or_else(|| {
+                        OlangError::new(
+                            format!(
+                                "undefined variable \
+                                 '{name}'"
+                            ),
+                            *span,
+                        )
+                    })?;
+
                 if !self.locals[slot].mutable {
                     return Err(OlangError::new(
-                        format!("cannot assign to immutable variable '{name}'"),
+                        format!(
+                            "cannot assign to \
+                             immutable variable \
+                             '{name}'"
+                        ),
                         *span,
                     ));
                 }
+
                 self.compile_expr(value)?;
-                self.chunk.emit_op(OpCode::SetLocal, 0);
+                self.chunk.emit_op(
+                    OpCode::SetLocal,
+                    0,
+                );
                 self.chunk.emit_byte(slot as u8, 0);
             }
+            Stmt::If {
+                condition,
+                then_block,
+                else_block,
+                ..
+            } => {
+                self.compile_expr(condition)?;
+
+                let then_jump = self.chunk.emit_jump(
+                    OpCode::JumpIfFalse,
+                    0,
+                );
+                self.chunk.emit_op(OpCode::Pop, 0);
+
+                self.compile_block(then_block)?;
+
+                if let Some(else_stmts) = else_block {
+                    let else_jump = self
+                        .chunk
+                        .emit_jump(OpCode::Jump, 0);
+
+                    self.chunk.patch_jump(then_jump);
+                    self.chunk.emit_op(OpCode::Pop, 0);
+
+                    self.compile_block(else_stmts)?;
+
+                    self.chunk.patch_jump(else_jump);
+                } else {
+                    self.chunk.patch_jump(then_jump);
+                    self.chunk.emit_op(OpCode::Pop, 0);
+                }
+            }
+            Stmt::While {
+                condition, body, ..
+            } => {
+                let loop_start = self.chunk.code.len();
+
+                self.compile_expr(condition)?;
+
+                let exit_jump = self.chunk.emit_jump(
+                    OpCode::JumpIfFalse,
+                    0,
+                );
+                self.chunk.emit_op(OpCode::Pop, 0);
+
+                self.compile_block(body)?;
+
+                self.chunk.emit_loop(loop_start, 0);
+
+                self.chunk.patch_jump(exit_jump);
+                self.chunk.emit_op(OpCode::Pop, 0);
+            }
+            Stmt::Return { value, .. } => {
+                if let Some(expr) = value {
+                    self.compile_expr(expr)?;
+                } else {
+                    self.chunk.emit_constant(
+                        Value::Int(0),
+                        0,
+                    );
+                }
+                self.chunk.emit_op(
+                    OpCode::Return,
+                    0,
+                );
+            }
             Stmt::Print { args, .. } => {
                 for arg in args {
                     self.compile_expr(arg)?;
                 }
                 self.chunk.emit_op(OpCode::Print, 0);
-                self.chunk.emit_byte(args.len() as u8, 0);
+                self.chunk
+                    .emit_byte(args.len() as u8, 0);
             }
+            Stmt::For {
+                var, iter, body, span,
+            } => {
+                // Desugar: for i in range(n) { body }
+                // into:    let mut i = 0;
+                //          while i < n { body; i=i+1 }
+
+                // Extract n from range(n)
+                let limit = match iter {
+                    Expr::Call {
+                        name, args, ..
+                    } if name == "range"
+                        && args.len() == 1 =>
+                    {
+                        &args[0]
+                    }
+                    _ => {
+                        return Err(OlangError::new(
+                            "for loops only support \
+                             range(n)",
+                            *span,
+                        ));
+                    }
+                };
+
+                let locals_before =
+                    self.locals.len();
+
+                // let mut i = 0
+                self.chunk.emit_constant(
+                    Value::Int(0),
+                    0,
+                );
+                let var_slot = self.locals.len();
+                self.locals.push(Local {
+                    name: var.clone(),
+                    mutable: true,
+                    ty: ValType::Int,
+                });
+
+                // while i < n
+                let loop_start =
+                    self.chunk.code.len();
+
+                self.chunk.emit_op(
+                    OpCode::GetLocal,
+                    0,
+                );
+                self.chunk.emit_byte(
+                    var_slot as u8,
+                    0,
+                );
+                self.compile_expr(limit)?;
+                self.chunk.emit_op(
+                    OpCode::Less,
+                    0,
+                );
+
+                let exit_jump =
+                    self.chunk.emit_jump(
+                        OpCode::JumpIfFalse,
+                        0,
+                    );
+                self.chunk.emit_op(OpCode::Pop, 0);
+
+                // body
+                self.compile_block(body)?;
+
+                // i = i + 1
+                self.chunk.emit_op(
+                    OpCode::GetLocal,
+                    0,
+                );
+                self.chunk.emit_byte(
+                    var_slot as u8,
+                    0,
+                );
+                self.chunk.emit_constant(
+                    Value::Int(1),
+                    0,
+                );
+                self.chunk.emit_op(
+                    OpCode::Add,
+                    0,
+                );
+                self.chunk.emit_op(
+                    OpCode::SetLocal,
+                    0,
+                );
+                self.chunk.emit_byte(
+                    var_slot as u8,
+                    0,
+                );
+
+                // loop back
+                self.chunk.emit_loop(loop_start, 0);
+
+                self.chunk.patch_jump(exit_jump);
+                self.chunk.emit_op(OpCode::Pop, 0);
+
+                // pop the loop variable
+                let locals_after =
+                    self.locals.len();
+                for _ in locals_before..locals_after {
+                    self.chunk.emit_op(
+                        OpCode::Pop,
+                        0,
+                    );
+                    self.locals.pop();
+                }
+            }
             Stmt::ExprStmt { expr, .. } => {
                 self.compile_expr(expr)?;
                 self.chunk.emit_op(OpCode::Pop, 0);
             }
-            _ => {
-                return Err(OlangError::new(
-                    format!("unsupported statement: {:?}", stmt),
-                    Span::new(0, 0),
-                ));
-            }
         }
         Ok(())
     }
 
-    fn compile_expr(&mut self, expr: &Expr) -> Result<(), OlangError> {
+    fn expr_span(expr: &Expr) -> Span {
         match expr {
-            Expr::IntLit(v, _) => self.chunk.emit_constant(Value::Int(*v), 0),
-            Expr::FloatLit(v, _) => self.chunk.emit_constant(Value::Float(*v), 0),
-            Expr::StrLit(s, _) => self.chunk.emit_constant(Value::Str(s.clone()), 0),
-            Expr::BoolLit(v, _) => self.chunk.emit_constant(Value::Bool(*v), 0),
+            Expr::IntLit(_, s) => *s,
+            Expr::FloatLit(_, s) => *s,
+            Expr::StrLit(_, s) => *s,
+            Expr::BoolLit(_, s) => *s,
+            Expr::Ident(_, s) => *s,
+            Expr::Unary { span, .. } => *span,
+            Expr::Binary { span, .. } => *span,
+            Expr::Call { span, .. } => *span,
+        }
+    }
+
+    fn infer_type(&self, expr: &Expr) -> ValType {
+        match expr {
+            Expr::IntLit(..) => ValType::Int,
+            Expr::FloatLit(..) => ValType::Float,
+            Expr::StrLit(..) => ValType::Str,
+            Expr::BoolLit(..) => ValType::Bool,
+            Expr::Ident(name, _) => {
+                if let Some(slot) =
+                    self.resolve_local(name)
+                {
+                    self.locals[slot].ty.clone()
+                } else {
+                    ValType::Unknown
+                }
+            }
+            Expr::Unary { op, expr, .. } => {
+                match op {
+                    UnaryOp::Negate => {
+                        self.infer_type(expr)
+                    }
+                    UnaryOp::Not => ValType::Bool,
+                }
+            }
+            Expr::Binary { op, left, .. } => {
+                match op {
+                    BinOp::Add
+                    | BinOp::Sub
+                    | BinOp::Mul
+                    | BinOp::Div
+                    | BinOp::Mod => {
+                        self.infer_type(left)
+                    }
+                    _ => ValType::Bool,
+                }
+            }
+            Expr::Call { .. } => ValType::Unknown,
+        }
+    }
+
+    fn compile_expr(
+        &mut self,
+        expr: &Expr,
+    ) -> Result<(), OlangError> {
+        match expr {
+            Expr::IntLit(v, _) => {
+                self.chunk.emit_constant(
+                    Value::Int(*v),
+                    0,
+                );
+            }
+            Expr::FloatLit(v, _) => {
+                self.chunk.emit_constant(
+                    Value::Float(*v),
+                    0,
+                );
+            }
+            Expr::StrLit(s, _) => {
+                self.chunk.emit_constant(
+                    Value::Str(s.clone()),
+                    0,
+                );
+            }
+            Expr::BoolLit(v, _) => {
+                self.chunk.emit_constant(
+                    Value::Bool(*v),
+                    0,
+                );
+            }
             Expr::Ident(name, span) => {
-                let slot = self.resolve_local(name).ok_or_else(|| {
-                    OlangError::new(format!("undefined variable '{name}'"), *span)
-                })?;
-                self.chunk.emit_op(OpCode::GetLocal, 0);
+                let slot = self
+                    .resolve_local(name)
+                    .ok_or_else(|| {
+                        OlangError::new(
+                            format!(
+                                "undefined variable \
+                                 '{name}'"
+                            ),
+                            *span,
+                        )
+                    })?;
+                self.chunk.emit_op(
+                    OpCode::GetLocal,
+                    0,
+                );
                 self.chunk.emit_byte(slot as u8, 0);
             }
             Expr::Unary { op, expr, .. } => {
                 self.compile_expr(expr)?;
                 match op {
-                    UnaryOp::Negate => self.chunk.emit_op(OpCode::Negate, 0),
+                    UnaryOp::Negate => {
+                        self.chunk.emit_op(
+                            OpCode::Negate,
+                            0,
+                        );
+                    }
                     UnaryOp::Not => {
-                        return Err(OlangError::new(
-                            "not operator not yet supported", Span::new(0, 0),
-                        ));
+                        self.chunk.emit_op(
+                            OpCode::Not,
+                            0,
+                        );
                     }
                 }
             }
-            Expr::Binary { op, left, right, .. } => {
+            Expr::Binary {
+                op, left, right, ..
+            } => {
                 self.compile_expr(left)?;
                 self.compile_expr(right)?;
                 let opcode = match op {
@@ -135,18 +652,50 @@ impl Compiler {
                     BinOp::Mul => OpCode::Mul,
                     BinOp::Div => OpCode::Div,
                     BinOp::Mod => OpCode::Mod,
-                    _ => return Err(OlangError::new(
-                        format!("unsupported binary op: {:?}", op),
-                        Span::new(0, 0),
-                    )),
+                    BinOp::Eq => OpCode::Equal,
+                    BinOp::NotEq => OpCode::NotEqual,
+                    BinOp::Less => OpCode::Less,
+                    BinOp::LessEq => OpCode::LessEqual,
+                    BinOp::Greater => OpCode::Greater,
+                    BinOp::GreaterEq => {
+                        OpCode::GreaterEqual
+                    }
                 };
                 self.chunk.emit_op(opcode, 0);
             }
-            Expr::Call { name, args: _, span } => {
-                return Err(OlangError::new(
-                    format!("function calls not yet supported: {name}"),
-                    *span,
-                ));
+            Expr::Call {
+                name, args, span,
+            } => {
+                // Push arguments onto stack
+                for arg in args {
+                    self.compile_expr(arg)?;
+                }
+
+                // Resolve function index
+                let fn_index = self
+                    .resolve_fn(name)
+                    .ok_or_else(|| {
+                        OlangError::new(
+                            format!(
+                                "undefined function \
+                                 '{name}'"
+                            ),
+                            *span,
+                        )
+                    })?;
+
+                self.chunk.emit_op(
+                    OpCode::Call,
+                    0,
+                );
+                self.chunk.emit_byte(
+                    fn_index as u8,
+                    0,
+                );
+                self.chunk.emit_byte(
+                    args.len() as u8,
+                    0,
+                );
             }
         }
         Ok(())