diff --git a/.gitignore b/.gitignore index 56598e97d..d3f4c1d31 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,9 @@ Cargo.lock # Gradle .idea/**/gradle.xml .idea/**/libraries + +# Build artifacts +*.dll +*.dll.lib +*.so +*.dylib diff --git a/azure-pipelines.yml b/azure-pipelines.yml index d2c18e0c6..43a1d1ee6 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -25,6 +25,10 @@ stages: parameters: name: rustfmt displayName: cargo fmt + - template: ci/azure-clippy.yml + parameters: + name: clippy + displayName: cargo clippy - stage: coverage displayName: Code Coverage @@ -32,4 +36,4 @@ stages: jobs: - template: ci/azure-coverage.yml parameters: - codecov_token: $(CODECOV_TOKEN) \ No newline at end of file + codecov_token: $(CODECOV_TOKEN) diff --git a/ci/azure-clippy.yml b/ci/azure-clippy.yml new file mode 100644 index 000000000..b39d684de --- /dev/null +++ b/ci/azure-clippy.yml @@ -0,0 +1,22 @@ +parameters: + rust_version: stable + +jobs: +- job: ${{ parameters.name }} + displayName: ${{ parameters.displayName }} + pool: + vmImage: ubuntu-18.04 + steps: + - checkout: self + submodules: true + + - template: azure-install-rust.yml + + - script: | + rustup component add clippy + cargo fmt --version + displayName: Install clippy + + - script: | + cargo fmt --all + displayName: Check clippy diff --git a/crates/mun/Cargo.toml b/crates/mun/Cargo.toml index 7b7df6af7..2a198436b 100644 --- a/crates/mun/Cargo.toml +++ b/crates/mun/Cargo.toml @@ -10,3 +10,8 @@ clap = "2.33.0" mun_compiler = { path = "../mun_compiler" } mun_compiler_daemon = { path = "../mun_compiler_daemon" } mun_runtime = { path = "../mun_runtime" } + +[dev-dependencies.cargo-husky] +version = "1" +default-features = false # Disable features which are enabled by default +features = ["precommit-hook", "run-cargo-test", "run-cargo-fmt", "run-cargo-clippy", "run-for-all"] diff --git a/crates/mun/main.dll b/crates/mun/main.dll deleted file mode 100644 index b120b125a..000000000 Binary files a/crates/mun/main.dll and /dev/null differ diff --git a/crates/mun/main.dll.lib b/crates/mun/main.dll.lib deleted file mode 100644 index cf5abbe06..000000000 Binary files a/crates/mun/main.dll.lib and /dev/null differ diff --git a/crates/mun_abi/src/autogen_impl.rs b/crates/mun_abi/src/autogen_impl.rs index 7d35de78d..a0ad13f5e 100644 --- a/crates/mun_abi/src/autogen_impl.rs +++ b/crates/mun_abi/src/autogen_impl.rs @@ -148,13 +148,13 @@ impl DispatchTable { /// This is generally not recommended, use with caution! Calling this method with an /// out-of-bounds index is _undefined behavior_ even if the resulting reference is not used. /// For a safe alternative see [get_ptr_mut](#method.get_ptr_mut). - pub unsafe fn get_ptr_unchecked_mut(&self, idx: u32) -> &mut *const c_void { + pub unsafe fn get_ptr_unchecked_mut(&mut self, idx: u32) -> &mut *const c_void { &mut *self.fn_ptrs.offset(idx as isize) } /// Returns a mutable reference to a function pointer at the given index, or `None` if out of /// bounds. - pub fn get_ptr_mut(&self, idx: u32) -> Option<&mut *const c_void> { + pub fn get_ptr_mut(&mut self, idx: u32) -> Option<&mut *const c_void> { if idx < self.num_entries { Some(unsafe { self.get_ptr_unchecked_mut(idx) }) } else { @@ -502,7 +502,7 @@ mod tests { let signatures = &[fn_signature]; let fn_ptrs = &mut [ptr::null()]; - let dispatch_table = fake_dispatch_table(signatures, fn_ptrs); + let mut dispatch_table = fake_dispatch_table(signatures, fn_ptrs); assert_eq!( unsafe { dispatch_table.get_ptr_unchecked_mut(0) }, &mut fn_ptrs[0] @@ -521,7 +521,7 @@ mod tests { let signatures = &[fn_signature]; let fn_ptrs = &mut [ptr::null()]; - let dispatch_table = fake_dispatch_table(signatures, fn_ptrs); + let mut dispatch_table = fake_dispatch_table(signatures, fn_ptrs); assert_eq!(dispatch_table.get_ptr_mut(1), None); } @@ -537,7 +537,7 @@ mod tests { let signatures = &[fn_signature]; let fn_ptrs = &mut [ptr::null()]; - let dispatch_table = fake_dispatch_table(signatures, fn_ptrs); + let mut dispatch_table = fake_dispatch_table(signatures, fn_ptrs); assert_eq!(dispatch_table.get_ptr_mut(0), Some(&mut fn_ptrs[0])); } diff --git a/crates/mun_codegen/src/db.rs b/crates/mun_codegen/src/db.rs index 086e0ad12..01561b8de 100644 --- a/crates/mun_codegen/src/db.rs +++ b/crates/mun_codegen/src/db.rs @@ -1,3 +1,5 @@ +#![allow(clippy::type_repetition_in_bounds)] + use mun_hir as hir; use crate::{code_gen::symbols::TypeInfo, ir::module::ModuleIR, Context}; diff --git a/crates/mun_codegen/src/ir.rs b/crates/mun_codegen/src/ir.rs index aed7a6cdc..807768dba 100644 --- a/crates/mun_codegen/src/ir.rs +++ b/crates/mun_codegen/src/ir.rs @@ -1,5 +1,6 @@ use inkwell::types::{AnyTypeEnum, BasicTypeEnum}; +pub mod body; pub(crate) mod dispatch_table; pub mod function; pub mod module; diff --git a/crates/mun_codegen/src/ir/body.rs b/crates/mun_codegen/src/ir/body.rs new file mode 100644 index 000000000..69762739c --- /dev/null +++ b/crates/mun_codegen/src/ir/body.rs @@ -0,0 +1,471 @@ +use crate::{ir::dispatch_table::DispatchTable, ir::try_convert_any_to_basic, IrDatabase}; +use inkwell::{ + builder::Builder, + module::Module, + values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue}, + FloatPredicate, IntPredicate, +}; +use mun_hir::{ + self as hir, ArithOp, BinaryOp, Body, CmpOp, Expr, ExprId, HirDisplay, InferenceResult, + Literal, Ordering, Pat, PatId, Path, Resolution, Resolver, Statement, TypeCtor, +}; +use std::{collections::HashMap, mem, sync::Arc}; + +mod name; +use name::OptName; + +pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> { + db: &'a D, + module: &'a Module, + body: Arc, + infer: Arc, + builder: Builder, + fn_value: FunctionValue, + pat_to_param: HashMap, + pat_to_local: HashMap, + pat_to_name: HashMap, + function_map: &'a HashMap, + dispatch_table: &'b DispatchTable, +} + +impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { + pub fn new( + db: &'a D, + module: &'a Module, + hir_function: hir::Function, + ir_function: FunctionValue, + function_map: &'a HashMap, + dispatch_table: &'b DispatchTable, + ) -> Self { + // Get the type information from the `hir::Function` + let body = hir_function.body(db); + let infer = hir_function.infer(db); + + // Construct a builder for the IR function + let context = module.get_context(); + let builder = context.create_builder(); + let body_ir = context.append_basic_block(&ir_function, "body"); + builder.position_at_end(&body_ir); + + BodyIrGenerator { + db, + module, + body, + infer, + builder, + fn_value: ir_function, + pat_to_param: HashMap::default(), + pat_to_local: HashMap::default(), + pat_to_name: HashMap::default(), + function_map, + dispatch_table, + } + } + + /// Generates IR for the body of the function. + pub fn gen_fn_body(&mut self) { + // Iterate over all parameters and their type and store them so we can reference them + // later in code. + for (i, (pat, _ty)) in self.body.params().iter().enumerate() { + let body = self.body.clone(); // Avoid borrow issues + + match &body[*pat] { + Pat::Bind { name } => { + let name = name.to_string(); + let param = self.fn_value.get_nth_param(i as u32).unwrap(); + param.set_name(&name); // Assign a name to the IR value consistent with the code. + self.pat_to_param.insert(*pat, param); + self.pat_to_name.insert(*pat, name); + } + Pat::Wild => { + // Wildcard patterns cannot be referenced from code. So nothing to do. + } + Pat::Path(_) => unreachable!( + "Path patterns are not supported as parameters, are we missing a diagnostic?" + ), + Pat::Missing => unreachable!( + "found missing Pattern, should not be generating IR for incomplete code" + ), + } + } + + // Generate code for the body of the function + let ret_value = self.gen_expr(self.body.body_expr()); + + // Construct a return statement from the returned value of the body + if let Some(value) = ret_value { + self.builder.build_return(Some(&value)); + } else { + self.builder.build_return(None); + } + } + + /// Generates IR for the specified expression. Dependending on the type of expression an IR + /// value is returned. + fn gen_expr(&mut self, expr: ExprId) -> Option { + let body = self.body.clone(); + match &body[expr] { + Expr::Block { + ref statements, + tail, + } => self.gen_block(expr, statements, *tail), + Expr::Path(ref p) => { + let resolver = mun_hir::resolver_for_expr(self.body.clone(), self.db, expr); + Some(self.gen_path_expr(p, expr, &resolver)) + } + Expr::Literal(lit) => Some(self.gen_literal(lit)), + Expr::BinaryOp { lhs, rhs, op } => { + Some(self.gen_binary_op(expr, *lhs, *rhs, op.expect("missing op"))) + } + Expr::Call { + ref callee, + ref args, + } => self.gen_call(*callee, &args).try_as_basic_value().left(), + Expr::If { + condition, + then_branch, + else_branch, + } => self.gen_if(expr, *condition, *then_branch, *else_branch), + _ => unimplemented!("unimplemented expr type {:?}", &body[expr]), + } + } + + /// Generates an IR value that represents the given `Literal`. + fn gen_literal(&mut self, lit: &Literal) -> BasicValueEnum { + match lit { + Literal::Int(v) => self + .module + .get_context() + .i64_type() + .const_int(unsafe { mem::transmute::(*v) }, true) + .into(), + + Literal::Float(v) => self + .module + .get_context() + .f64_type() + .const_float(*v as f64) + .into(), + + Literal::Bool(value) => self + .module + .get_context() + .bool_type() + .const_int(if *value { 1 } else { 0 }, false) + .into(), + + Literal::String(_) => unimplemented!("string literals are not implemented yet"), + } + } + + /// Generates IR for the specified block expression. + fn gen_block( + &mut self, + _tgt_expr: ExprId, + statements: &[Statement], + tail: Option, + ) -> Option { + for statement in statements.iter() { + match statement { + Statement::Let { + pat, initializer, .. + } => { + self.gen_let_statement(*pat, *initializer); + } + Statement::Expr(expr) => { + self.gen_expr(*expr); + } + }; + } + tail.and_then(|expr| self.gen_expr(expr)) + } + + /// Constructs a builder that should be used to emit an `alloca` instruction. These instructions + /// should be at the start of the IR. + fn new_alloca_builder(&self) -> Builder { + let temp_builder = Builder::create(); + let block = self + .fn_value + .get_first_basic_block() + .expect("at this stage there must be a block"); + if let Some(first_instruction) = block.get_first_instruction() { + temp_builder.position_before(&first_instruction); + } else { + temp_builder.position_at_end(&block); + } + temp_builder + } + + /// Generate IR for a let statement: `let a:int = 3` + fn gen_let_statement(&mut self, pat: PatId, initializer: Option) { + let initializer = initializer.and_then(|expr| self.gen_expr(expr)); + + match &self.body[pat] { + Pat::Bind { name } => { + let builder = self.new_alloca_builder(); + let pat_ty = self.infer[pat].clone(); + let ty = try_convert_any_to_basic(self.db.type_ir(pat_ty.clone())) + .expect("expected basic type"); + let ptr = builder.build_alloca(ty, &name.to_string()); + self.pat_to_local.insert(pat, ptr); + self.pat_to_name.insert(pat, name.to_string()); + if !(pat_ty.is_empty() || pat_ty.is_never()) { + if let Some(value) = initializer { + self.builder.build_store(ptr, value); + }; + } + } + Pat::Wild => {} + Pat::Missing | Pat::Path(_) => unreachable!(), + } + } + + /// Generates IR for looking up a certain path expression. + fn gen_path_expr( + &self, + path: &Path, + _expr: ExprId, + resolver: &Resolver, + ) -> inkwell::values::BasicValueEnum { + let resolution = resolver + .resolve_path_without_assoc_items(self.db, path) + .take_values() + .expect("unknown path"); + + match resolution { + Resolution::LocalBinding(pat) => { + if let Some(param) = self.pat_to_param.get(&pat) { + *param + } else if let Some(ptr) = self.pat_to_local.get(&pat) { + let name = self.pat_to_name.get(&pat).expect("could not find pat name"); + self.builder.build_load(*ptr, &name) + } else { + unreachable!("could not find the pattern.."); + } + } + Resolution::Def(_) => panic!("no support for module definitions"), + } + } + + /// Generates IR to calculate a binary operation between two expressions. + fn gen_binary_op( + &mut self, + _tgt_expr: ExprId, + lhs: ExprId, + rhs: ExprId, + op: BinaryOp, + ) -> BasicValueEnum { + let lhs_value = self.gen_expr(lhs).expect("no lhs value"); + let rhs_value = self.gen_expr(rhs).expect("no rhs value"); + let lhs_type = self.infer[lhs].clone(); + let rhs_type = self.infer[rhs].clone(); + + match lhs_type.as_simple() { + Some(TypeCtor::Float) => self.gen_binary_op_float( + *lhs_value.as_float_value(), + *rhs_value.as_float_value(), + op, + ), + Some(TypeCtor::Int) => { + self.gen_binary_op_int(*lhs_value.as_int_value(), *rhs_value.as_int_value(), op) + } + _ => unimplemented!( + "unimplemented operation {0}op{1}", + lhs_type.display(self.db), + rhs_type.display(self.db) + ), + } + } + + /// Generates IR to calculate a binary operation between two floating point values. + fn gen_binary_op_float( + &mut self, + lhs: FloatValue, + rhs: FloatValue, + op: BinaryOp, + ) -> BasicValueEnum { + match op { + BinaryOp::ArithOp(ArithOp::Add) => self.builder.build_float_add(lhs, rhs, "add").into(), + BinaryOp::ArithOp(ArithOp::Subtract) => { + self.builder.build_float_sub(lhs, rhs, "sub").into() + } + BinaryOp::ArithOp(ArithOp::Divide) => { + self.builder.build_float_div(lhs, rhs, "div").into() + } + BinaryOp::ArithOp(ArithOp::Multiply) => { + self.builder.build_float_mul(lhs, rhs, "mul").into() + } + BinaryOp::CmpOp(op) => { + let (name, predicate) = match op { + CmpOp::Eq { negated: false } => ("eq", FloatPredicate::OEQ), + CmpOp::Eq { negated: true } => ("neq", FloatPredicate::ONE), + CmpOp::Ord { + ordering: Ordering::Less, + strict: false, + } => ("lesseq", FloatPredicate::OLE), + CmpOp::Ord { + ordering: Ordering::Less, + strict: true, + } => ("less", FloatPredicate::OLT), + CmpOp::Ord { + ordering: Ordering::Greater, + strict: false, + } => ("greatereq", FloatPredicate::OGE), + CmpOp::Ord { + ordering: Ordering::Greater, + strict: true, + } => ("greater", FloatPredicate::OGT), + }; + self.builder + .build_float_compare(predicate, lhs, rhs, name) + .into() + } + _ => unimplemented!("Operator {:?} is not implemented for float", op), + } + } + + /// Generates IR to calculate a binary operation between two integer values. + fn gen_binary_op_int(&mut self, lhs: IntValue, rhs: IntValue, op: BinaryOp) -> BasicValueEnum { + match op { + BinaryOp::ArithOp(ArithOp::Add) => self.builder.build_int_add(lhs, rhs, "add").into(), + BinaryOp::ArithOp(ArithOp::Subtract) => { + self.builder.build_int_sub(lhs, rhs, "sub").into() + } + BinaryOp::ArithOp(ArithOp::Divide) => { + self.builder.build_int_signed_div(lhs, rhs, "div").into() + } + BinaryOp::ArithOp(ArithOp::Multiply) => { + self.builder.build_int_mul(lhs, rhs, "mul").into() + } + BinaryOp::CmpOp(op) => { + let (name, predicate) = match op { + CmpOp::Eq { negated: false } => ("eq", IntPredicate::EQ), + CmpOp::Eq { negated: true } => ("neq", IntPredicate::NE), + CmpOp::Ord { + ordering: Ordering::Less, + strict: false, + } => ("lesseq", IntPredicate::SLE), + CmpOp::Ord { + ordering: Ordering::Less, + strict: true, + } => ("less", IntPredicate::SLT), + CmpOp::Ord { + ordering: Ordering::Greater, + strict: false, + } => ("greatereq", IntPredicate::SGE), + CmpOp::Ord { + ordering: Ordering::Greater, + strict: true, + } => ("greater", IntPredicate::SGT), + }; + self.builder + .build_int_compare(predicate, lhs, rhs, name) + .into() + } + _ => unreachable!(format!("Operator {:?} is not implemented for integer", op)), + } + } + + // TODO: Implement me! + fn should_use_dispatch_table(&self) -> bool { + true + } + + /// Generates IR for a function call. + fn gen_call(&mut self, callee: ExprId, args: &[ExprId]) -> CallSiteValue { + // Get the function value from the map + let function = self.infer[callee] + .as_function_def() + .expect("expected a function expression"); + + // Get all the arguments + let args: Vec = args + .iter() + .map(|expr| self.gen_expr(*expr).expect("expected a value")) + .collect(); + + if self.should_use_dispatch_table() { + let ptr_value = + self.dispatch_table + .gen_function_lookup(self.db, &self.builder, function); + self.builder + .build_call(ptr_value, &args, &function.name(self.db).to_string()) + } else { + let llvm_function = self + .function_map + .get(&function) + .expect("missing function value for hir function"); + self.builder + .build_call(*llvm_function, &args, &function.name(self.db).to_string()) + } + } + + /// Generates IR for an if statement. + fn gen_if( + &mut self, + _expr: ExprId, + condition: ExprId, + then_branch: ExprId, + else_branch: Option, + ) -> Option { + // Generate IR for the condition + let condition_ir = self + .gen_expr(condition) + .expect("condition must have a value") + .into_int_value(); + + // Generate the code blocks to branch to + let context = self.module.get_context(); + let mut then_block = context.append_basic_block(&self.fn_value, "then"); + let else_block_and_expr = match &else_branch { + Some(else_branch) => Some(( + context.append_basic_block(&self.fn_value, "else"), + else_branch, + )), + None => None, + }; + let merge_block = context.append_basic_block(&self.fn_value, "if_merge"); + + // Build the actual branching IR for the if statement + let else_block = else_block_and_expr + .as_ref() + .map(|e| &e.0) + .unwrap_or(&merge_block); + self.builder + .build_conditional_branch(condition_ir, &then_block, else_block); + + // Fill the then block + self.builder.position_at_end(&then_block); + let then_block_ir = self.gen_expr(then_branch); + self.builder.build_unconditional_branch(&merge_block); + then_block = self.builder.get_insert_block().unwrap(); + + // Fill the else block, if it exists and get the result back + let else_ir_and_block = if let Some((else_block, else_branch)) = else_block_and_expr { + else_block.move_after(&then_block); + self.builder.position_at_end(&else_block); + let result_ir = self.gen_expr(*else_branch); + self.builder.build_unconditional_branch(&merge_block); + result_ir.map(|res| (res, self.builder.get_insert_block().unwrap())) + } else { + None + }; + + // Create merge block + merge_block.move_after(&self.builder.get_insert_block().unwrap()); + self.builder.position_at_end(&merge_block); + + // Construct phi block if a value was returned + if let Some(then_block_ir) = then_block_ir { + if let Some((else_block_ir, else_block)) = else_ir_and_block { + let phi = self.builder.build_phi(then_block_ir.get_type(), "iftmp"); + phi.add_incoming(&[(&then_block_ir, &then_block), (&else_block_ir, &else_block)]); + Some(phi.as_basic_value()) + } else { + Some(then_block_ir) + } + } else { + None + } + } +} diff --git a/crates/mun_codegen/src/ir/body/name.rs b/crates/mun_codegen/src/ir/body/name.rs new file mode 100644 index 000000000..c8a29aafb --- /dev/null +++ b/crates/mun_codegen/src/ir/body/name.rs @@ -0,0 +1,30 @@ +use inkwell::values::BasicValueEnum; + +pub(crate) trait OptName { + fn get_name(&self) -> Option<&str>; + fn set_name>(&self, name: T); +} + +impl OptName for BasicValueEnum { + fn get_name(&self) -> Option<&str> { + match self { + BasicValueEnum::ArrayValue(v) => v.get_name().to_str().ok(), + BasicValueEnum::IntValue(v) => v.get_name().to_str().ok(), + BasicValueEnum::FloatValue(v) => v.get_name().to_str().ok(), + BasicValueEnum::PointerValue(v) => v.get_name().to_str().ok(), + BasicValueEnum::StructValue(v) => v.get_name().to_str().ok(), + BasicValueEnum::VectorValue(v) => v.get_name().to_str().ok(), + } + } + + fn set_name>(&self, name: T) { + match self { + BasicValueEnum::ArrayValue(v) => v.set_name(name.as_ref()), + BasicValueEnum::IntValue(v) => v.set_name(name.as_ref()), + BasicValueEnum::FloatValue(v) => v.set_name(name.as_ref()), + BasicValueEnum::PointerValue(v) => v.set_name(name.as_ref()), + BasicValueEnum::StructValue(v) => v.set_name(name.as_ref()), + BasicValueEnum::VectorValue(v) => v.set_name(name.as_ref()), + }; + } +} diff --git a/crates/mun_codegen/src/ir/function.rs b/crates/mun_codegen/src/ir/function.rs index 440095c1c..1ccaa257b 100644 --- a/crates/mun_codegen/src/ir/function.rs +++ b/crates/mun_codegen/src/ir/function.rs @@ -1,20 +1,11 @@ -use super::try_convert_any_to_basic; +use crate::ir::body::BodyIrGenerator; use crate::ir::dispatch_table::DispatchTable; -use crate::values::{ - BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, InstructionOpcode, IntValue, -}; +use crate::values::FunctionValue; use crate::{IrDatabase, Module, OptimizationLevel}; -use inkwell::builder::Builder; use inkwell::passes::{PassManager, PassManagerBuilder}; -use inkwell::types::{AnyTypeEnum, BasicTypeEnum}; -use inkwell::{FloatPredicate, IntPredicate}; -use mun_hir::{ - self as hir, ArithOp, BinaryOp, Body, CmpOp, Expr, ExprId, HirDisplay, InferenceResult, - Literal, Ordering, Pat, PatId, Path, Resolution, Resolver, Statement, TypeCtor, -}; +use inkwell::types::AnyTypeEnum; +use mun_hir as hir; use std::collections::HashMap; -use std::mem; -use std::sync::Arc; /// Constructs a PassManager to optimize functions for the given optimization level. pub(crate) fn create_pass_manager( @@ -57,18 +48,12 @@ pub(crate) fn gen_body<'a, 'b, D: IrDatabase>( llvm_functions: &'a HashMap, dispatch_table: &'b DispatchTable, ) -> FunctionValue { - let context = db.context(); - let builder = context.create_builder(); - let body_ir = context.append_basic_block(&llvm_function, "body"); - builder.position_at_end(&body_ir); - let mut code_gen = BodyIrGenerator::new( db, module, hir_function, llvm_function, llvm_functions, - builder, dispatch_table, ); @@ -76,393 +61,3 @@ pub(crate) fn gen_body<'a, 'b, D: IrDatabase>( llvm_function } - -struct BodyIrGenerator<'a, 'b, D: IrDatabase> { - db: &'a D, - module: &'a Module, - body: Arc, - infer: Arc, - builder: Builder, - fn_value: FunctionValue, - pat_to_param: HashMap, - pat_to_local: HashMap, - pat_to_name: HashMap, - function_map: &'a HashMap, - dispatch_table: &'b DispatchTable, -} - -impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { - fn new( - db: &'a D, - module: &'a Module, - f: hir::Function, - fn_value: FunctionValue, - function_map: &'a HashMap, - builder: Builder, - dispatch_table: &'b DispatchTable, - ) -> Self { - let body = f.body(db); - let infer = f.infer(db); - - BodyIrGenerator { - db, - module, - body, - infer, - builder, - fn_value, - pat_to_param: HashMap::default(), - pat_to_local: HashMap::default(), - pat_to_name: HashMap::default(), - function_map, - dispatch_table, - } - } - - fn gen_fn_body(&mut self) { - for (i, (pat, _ty)) in self.body.params().iter().enumerate() { - let body = self.body.clone(); // Avoid borrow issues - match &body[*pat] { - Pat::Bind { name } => { - let name = name.to_string(); - let param = self.fn_value.get_nth_param(i as u32).unwrap(); - param.set_name(&name); - self.pat_to_param.insert(*pat, param); - self.pat_to_name.insert(*pat, name); - } - Pat::Wild => {} - Pat::Missing | Pat::Path(_) => unreachable!(), - } - } - - let ret_value = self.gen_expr(self.body.body_expr()); - if let Some(value) = ret_value { - self.builder.build_return(Some(&value)); - } else { - self.builder.build_return(None); - } - } - - fn gen_expr(&mut self, expr: ExprId) -> Option { - let body = self.body.clone(); - let mut value = match &body[expr] { - &Expr::Block { - ref statements, - tail, - } => { - for statement in statements.iter() { - match statement { - Statement::Let { - pat, initializer, .. - } => { - self.gen_let_statement(*pat, *initializer); - } - Statement::Expr(expr) => { - self.gen_expr(*expr); - } - }; - } - tail.and_then(|expr| self.gen_expr(expr)) - } - Expr::Path(ref p) => { - let resolver = mun_hir::resolver_for_expr(self.body.clone(), self.db, expr); - Some(self.gen_path_expr(p, expr, &resolver)) - } - Expr::Literal(lit) => match lit { - Literal::Int(v) => Some( - self.module - .get_context() - .i64_type() - .const_int(unsafe { mem::transmute::(*v) }, true) - .into(), - ), - Literal::Float(v) => Some( - self.module - .get_context() - .f64_type() - .const_float(*v as f64) - .into(), - ), - Literal::String(_) | Literal::Bool(_) => unreachable!(), - }, - &Expr::BinaryOp { lhs, rhs, op } => { - Some(self.gen_binary_op(lhs, rhs, op.expect("missing op"))) - } - Expr::Call { - ref callee, - ref args, - } => self.gen_call(*callee, &args).try_as_basic_value().left(), - _ => unreachable!("unimplemented expr type"), - }; - - // Check expected type or perform implicit cast - value = value.map(|value| { - match ( - value.get_type(), - try_convert_any_to_basic(self.db.type_ir(self.infer[expr].clone())), - ) { - (BasicTypeEnum::IntType(_), Some(target @ BasicTypeEnum::FloatType(_))) => self - .builder - .build_cast(InstructionOpcode::SIToFP, value, target, "implicit_cast"), - (a, Some(b)) if a == b => value, - _ => unreachable!("could not perform implicit cast"), - } - }); - - value - } - - /// Constructs a builder that should be used to emit an `alloca` instruction. These instructions - /// should be at the start of the IR. - fn new_alloca_builder(&self) -> Builder { - let temp_builder = Builder::create(); - let block = self - .builder - .get_insert_block() - .expect("at this stage there must be a block"); - if let Some(first_instruction) = block.get_first_instruction() { - temp_builder.position_before(&first_instruction); - } else { - temp_builder.position_at_end(&block); - } - temp_builder - } - - /// Generate IR for a let statement: `let a:int = 3` - fn gen_let_statement(&mut self, pat: PatId, initializer: Option) { - let initializer = initializer.and_then(|expr| self.gen_expr(expr)); - - match &self.body[pat] { - Pat::Bind { name } => { - let builder = self.new_alloca_builder(); - let ty = try_convert_any_to_basic(self.db.type_ir(self.infer[pat].clone())) - .expect("expected basic type"); - let ptr = builder.build_alloca(ty, &name.to_string()); - self.pat_to_local.insert(pat, ptr); - self.pat_to_name.insert(pat, name.to_string()); - if let Some(value) = initializer { - self.builder.build_store(ptr, value); - }; - } - Pat::Wild => {} - Pat::Missing | Pat::Path(_) => unreachable!(), - } - } - - fn gen_path_expr( - &self, - path: &Path, - _expr: ExprId, - resolver: &Resolver, - ) -> inkwell::values::BasicValueEnum { - let resolution = resolver - .resolve_path_without_assoc_items(self.db, path) - .take_values() - .expect("unknown path"); - - match resolution { - Resolution::LocalBinding(pat) => { - if let Some(param) = self.pat_to_param.get(&pat) { - *param - } else if let Some(ptr) = self.pat_to_local.get(&pat) { - let name = self.pat_to_name.get(&pat).expect("could not find pat name"); - self.builder.build_load(*ptr, &name) - } else { - unreachable!("could not find the pattern.."); - } - } - Resolution::Def(_) => panic!("no support for module definitions"), - } - } - - fn gen_binary_op(&mut self, lhs: ExprId, rhs: ExprId, op: BinaryOp) -> BasicValueEnum { - let lhs_value = self.gen_expr(lhs).expect("no lhs value"); - let rhs_value = self.gen_expr(rhs).expect("no rhs value"); - let lhs_type = self.infer[lhs].clone(); - let rhs_type = self.infer[rhs].clone(); - - match lhs_type.as_simple() { - Some(TypeCtor::Float) => self.gen_binary_op_float( - *lhs_value.as_float_value(), - *rhs_value.as_float_value(), - op, - ), - Some(TypeCtor::Int) => { - self.gen_binary_op_int(*lhs_value.as_int_value(), *rhs_value.as_int_value(), op) - } - _ => unreachable!( - "Unsupported operation {0}op{1}", - lhs_type.display(self.db), - rhs_type.display(self.db) - ), - } - } - - fn gen_binary_op_float( - &mut self, - lhs: FloatValue, - rhs: FloatValue, - op: BinaryOp, - ) -> BasicValueEnum { - match op { - BinaryOp::ArithOp(ArithOp::Add) => self.builder.build_float_add(lhs, rhs, "add").into(), - BinaryOp::ArithOp(ArithOp::Subtract) => { - self.builder.build_float_sub(lhs, rhs, "sub").into() - } - BinaryOp::ArithOp(ArithOp::Divide) => { - self.builder.build_float_div(lhs, rhs, "div").into() - } - BinaryOp::ArithOp(ArithOp::Multiply) => { - self.builder.build_float_mul(lhs, rhs, "mul").into() - } - // BinaryOp::Remainder => Some(self.gen_remainder(lhs, rhs)), - // BinaryOp::Power =>, - // BinaryOp::Assign, - // BinaryOp::AddAssign, - // BinaryOp::SubtractAssign, - // BinaryOp::DivideAssign, - // BinaryOp::MultiplyAssign, - // BinaryOp::RemainderAssign, - // BinaryOp::PowerAssign, - BinaryOp::CmpOp(op) => { - let (name, predicate) = match op { - CmpOp::Eq { negated: false } => ("eq", FloatPredicate::OEQ), - CmpOp::Eq { negated: true } => ("neq", FloatPredicate::ONE), - CmpOp::Ord { - ordering: Ordering::Less, - strict: false, - } => ("lesseq", FloatPredicate::OLE), - CmpOp::Ord { - ordering: Ordering::Less, - strict: true, - } => ("less", FloatPredicate::OLT), - CmpOp::Ord { - ordering: Ordering::Greater, - strict: false, - } => ("greatereq", FloatPredicate::OGE), - CmpOp::Ord { - ordering: Ordering::Greater, - strict: true, - } => ("greater", FloatPredicate::OGT), - }; - self.builder - .build_float_compare(predicate, lhs, rhs, name) - .into() - } - _ => unreachable!(), - } - } - - fn gen_binary_op_int(&mut self, lhs: IntValue, rhs: IntValue, op: BinaryOp) -> BasicValueEnum { - match op { - BinaryOp::ArithOp(ArithOp::Add) => self.builder.build_int_add(lhs, rhs, "add").into(), - BinaryOp::ArithOp(ArithOp::Subtract) => { - self.builder.build_int_sub(lhs, rhs, "sub").into() - } - BinaryOp::ArithOp(ArithOp::Divide) => { - self.builder.build_int_signed_div(lhs, rhs, "div").into() - } - BinaryOp::ArithOp(ArithOp::Multiply) => { - self.builder.build_int_mul(lhs, rhs, "mul").into() - } - // BinaryOp::Remainder => Some(self.gen_remainder(lhs, rhs)), - // BinaryOp::Power =>, - // BinaryOp::Assign, - // BinaryOp::AddAssign, - // BinaryOp::SubtractAssign, - // BinaryOp::DivideAssign, - // BinaryOp::MultiplyAssign, - // BinaryOp::RemainderAssign, - // BinaryOp::PowerAssign, - BinaryOp::CmpOp(op) => { - let (name, predicate) = match op { - CmpOp::Eq { negated: false } => ("eq", IntPredicate::EQ), - CmpOp::Eq { negated: true } => ("neq", IntPredicate::NE), - CmpOp::Ord { - ordering: Ordering::Less, - strict: false, - } => ("lesseq", IntPredicate::SLE), - CmpOp::Ord { - ordering: Ordering::Less, - strict: true, - } => ("less", IntPredicate::SLT), - CmpOp::Ord { - ordering: Ordering::Greater, - strict: false, - } => ("greatereq", IntPredicate::SGE), - CmpOp::Ord { - ordering: Ordering::Greater, - strict: true, - } => ("greater", IntPredicate::SGT), - }; - self.builder - .build_int_compare(predicate, lhs, rhs, name) - .into() - } - _ => unreachable!(), - } - } - - // TODO: Implement me! - fn should_use_dispatch_table(&self) -> bool { - true - } - - /// Generates IR for a function call. - fn gen_call(&mut self, callee: ExprId, args: &[ExprId]) -> CallSiteValue { - // Get the function value from the map - let function = self.infer[callee] - .as_function_def() - .expect("expected a function expression"); - - // Get all the arguments - let args: Vec = args - .iter() - .map(|expr| self.gen_expr(*expr).expect("expected a value")) - .collect(); - - if self.should_use_dispatch_table() { - let ptr_value = - self.dispatch_table - .gen_function_lookup(self.db, &self.builder, function); - self.builder - .build_call(ptr_value, &args, &function.name(self.db).to_string()) - } else { - let llvm_function = self - .function_map - .get(&function) - .expect("missing function value for hir function"); - self.builder - .build_call(*llvm_function, &args, &function.name(self.db).to_string()) - } - } -} - -trait OptName { - fn get_name(&self) -> Option<&str>; - fn set_name>(&self, name: T); -} - -impl OptName for BasicValueEnum { - fn get_name(&self) -> Option<&str> { - match self { - BasicValueEnum::ArrayValue(v) => v.get_name().to_str().ok(), - BasicValueEnum::IntValue(v) => v.get_name().to_str().ok(), - BasicValueEnum::FloatValue(v) => v.get_name().to_str().ok(), - BasicValueEnum::PointerValue(v) => v.get_name().to_str().ok(), - BasicValueEnum::StructValue(v) => v.get_name().to_str().ok(), - BasicValueEnum::VectorValue(v) => v.get_name().to_str().ok(), - } - } - - fn set_name>(&self, name: T) { - match self { - BasicValueEnum::ArrayValue(v) => v.set_name(name.as_ref()), - BasicValueEnum::IntValue(v) => v.set_name(name.as_ref()), - BasicValueEnum::FloatValue(v) => v.set_name(name.as_ref()), - BasicValueEnum::PointerValue(v) => v.set_name(name.as_ref()), - BasicValueEnum::StructValue(v) => v.set_name(name.as_ref()), - BasicValueEnum::VectorValue(v) => v.set_name(name.as_ref()), - }; - } -} diff --git a/crates/mun_codegen/src/ir/ty.rs b/crates/mun_codegen/src/ir/ty.rs index f038b68b4..d5cab68e5 100644 --- a/crates/mun_codegen/src/ir/ty.rs +++ b/crates/mun_codegen/src/ir/ty.rs @@ -7,7 +7,7 @@ use mun_hir::{ApplicationTy, Ty, TypeCtor}; pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty) -> AnyTypeEnum { let context = db.context(); match ty { - Ty::Empty => AnyTypeEnum::VoidType(context.void_type()), + Ty::Empty => AnyTypeEnum::StructType(context.struct_type(&[], false)), Ty::Apply(ApplicationTy { ctor, .. }) => match ctor { TypeCtor::Float => AnyTypeEnum::FloatType(context.f64_type()), TypeCtor::Int => AnyTypeEnum::IntType(context.i64_type()), @@ -19,13 +19,17 @@ pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty) -> AnyTypeEnum { .iter() .map(|p| try_convert_any_to_basic(db.type_ir(p.clone())).unwrap()) .collect(); - let ret_ty = match db.type_ir(ty.ret().clone()) { - AnyTypeEnum::VoidType(v) => return v.fn_type(¶ms, false).into(), - v => try_convert_any_to_basic(v).expect("could not convert return value"), + + let fn_type = match ty.ret() { + Ty::Empty => context.void_type().fn_type(¶ms, false), + ty => try_convert_any_to_basic(db.type_ir(ty.clone())) + .expect("could not convert return value") + .fn_type(¶ms, false), }; - ret_ty.fn_type(¶ms, false).into() + AnyTypeEnum::FunctionType(fn_type) } + _ => unreachable!(), }, _ => unreachable!("unknown type can not be converted"), } diff --git a/crates/mun_codegen/src/snapshots/test__fibonacci.snap b/crates/mun_codegen/src/snapshots/test__fibonacci.snap new file mode 100644 index 000000000..3b3d9de42 --- /dev/null +++ b/crates/mun_codegen/src/snapshots/test__fibonacci.snap @@ -0,0 +1,31 @@ +--- +source: crates/mun_codegen/src/test.rs +expression: "fn fibonacci(n:int):int {\n if n <= 1 {\n n\n } else {\n fibonacci(n-1) + fibonacci(n-2)\n }\n}" +--- +; ModuleID = 'main.mun' +source_filename = "main.mun" + +%DispatchTable = type { i64 (i64)* } + +@dispatchTable = global %DispatchTable { i64 (i64)* @fibonacci } + +define i64 @fibonacci(i64 %n) { +body: + %lesseq = icmp sle i64 %n, 1 + br i1 %lesseq, label %if_merge, label %else + +else: ; preds = %body + %sub = sub i64 %n, 1 + %fibonacci_ptr = load i64 (i64)*, i64 (i64)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 0) + %fibonacci = call i64 %fibonacci_ptr(i64 %sub) + %sub1 = sub i64 %n, 2 + %fibonacci_ptr2 = load i64 (i64)*, i64 (i64)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 0) + %fibonacci3 = call i64 %fibonacci_ptr2(i64 %sub1) + %add = add i64 %fibonacci, %fibonacci3 + br label %if_merge + +if_merge: ; preds = %body, %else + %iftmp = phi i64 [ %add, %else ], [ %n, %body ] + ret i64 %iftmp +} + diff --git a/crates/mun_codegen/src/snapshots/test__if_statement.snap b/crates/mun_codegen/src/snapshots/test__if_statement.snap new file mode 100644 index 000000000..a65cad628 --- /dev/null +++ b/crates/mun_codegen/src/snapshots/test__if_statement.snap @@ -0,0 +1,28 @@ +--- +source: crates/mun_codegen/src/test.rs +expression: "fn foo(a:int):int {\n let b = if a > 3 {\n let c = if a > 4 {\n a+1\n } else {\n a+3\n }\n c\n } else {\n a-1\n }\n b\n}" +--- +; ModuleID = 'main.mun' +source_filename = "main.mun" + +define i64 @foo(i64 %a) { +body: + %greater = icmp sgt i64 %a, 3 + br i1 %greater, label %then, label %else + +then: ; preds = %body + %greater1 = icmp sgt i64 %a, 4 + %add = add i64 %a, 1 + %add5 = add i64 %a, 3 + %iftmp = select i1 %greater1, i64 %add, i64 %add5 + br label %if_merge + +else: ; preds = %body + %sub = sub i64 %a, 1 + br label %if_merge + +if_merge: ; preds = %else, %then + %iftmp7 = phi i64 [ %iftmp, %then ], [ %sub, %else ] + ret i64 %iftmp7 +} + diff --git a/crates/mun_codegen/src/snapshots/test__void_return.snap b/crates/mun_codegen/src/snapshots/test__void_return.snap new file mode 100644 index 000000000..4acc144bd --- /dev/null +++ b/crates/mun_codegen/src/snapshots/test__void_return.snap @@ -0,0 +1,23 @@ +--- +source: crates/mun_codegen/src/test.rs +expression: "fn bar() {\n let a = 3;\n}\nfn foo(a:int) {\n let c = bar()\n}" +--- +; ModuleID = 'main.mun' +source_filename = "main.mun" + +%DispatchTable = type { void ()* } + +@dispatchTable = global %DispatchTable { void ()* @bar } + +define void @bar() { +body: + ret void +} + +define void @foo(i64 %a) { +body: + %bar_ptr = load void ()*, void ()** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 0) + call void %bar_ptr() + ret void +} + diff --git a/crates/mun_codegen/src/test.rs b/crates/mun_codegen/src/test.rs index f21e449e1..d33b0a088 100644 --- a/crates/mun_codegen/src/test.rs +++ b/crates/mun_codegen/src/test.rs @@ -6,45 +6,6 @@ use mun_hir::SourceDatabase; use std::cell::RefCell; use std::sync::Arc; -fn test_snapshot(text: &str) { - let text = text.trim().replace("\n ", "\n"); - - let (db, file_id) = MockDatabase::with_single_file(&text); - - let line_index: Arc = db.line_index(file_id); - let messages = RefCell::new(Vec::new()); - let mut sink = DiagnosticSink::new(|diag| { - let line_col = line_index.line_col(diag.highlight_range().start()); - messages.borrow_mut().push(format!( - "error {}:{}: {}", - line_col.line + 1, - line_col.col + 1, - diag.message() - )); - }); - if let Some(module) = Module::package_modules(&db) - .iter() - .find(|m| m.file_id() == file_id) - { - module.diagnostics(&db, &mut sink) - } - drop(sink); - let messages = messages.into_inner(); - - let name = if !messages.is_empty() { - messages.join("\n") - } else { - format!( - "{}", - db.module_ir(file_id) - .llvm_module - .print_to_string() - .to_string() - ) - }; - insta::assert_snapshot!(insta::_macro_support::AutoName, name, &text); -} - #[test] fn function() { test_snapshot( @@ -185,3 +146,92 @@ fn equality_operands() { "#, ); } + +#[test] +fn if_statement() { + test_snapshot( + r#" + fn foo(a:int):int { + let b = if a > 3 { + let c = if a > 4 { + a+1 + } else { + a+3 + } + c + } else { + a-1 + } + b + } + "#, + ) +} + +#[test] +fn void_return() { + test_snapshot( + r#" + fn bar() { + let a = 3; + } + fn foo(a:int) { + let c = bar() + } + "#, + ) +} + +#[test] +fn fibonacci() { + test_snapshot( + r#" + fn fibonacci(n:int):int { + if n <= 1 { + n + } else { + fibonacci(n-1) + fibonacci(n-2) + } + } + "#, + ) +} + +fn test_snapshot(text: &str) { + let text = text.trim().replace("\n ", "\n"); + + let (db, file_id) = MockDatabase::with_single_file(&text); + + let line_index: Arc = db.line_index(file_id); + let messages = RefCell::new(Vec::new()); + let mut sink = DiagnosticSink::new(|diag| { + let line_col = line_index.line_col(diag.highlight_range().start()); + messages.borrow_mut().push(format!( + "error {}:{}: {}", + line_col.line + 1, + line_col.col + 1, + diag.message() + )); + }); + if let Some(module) = Module::package_modules(&db) + .iter() + .find(|m| m.file_id() == file_id) + { + module.diagnostics(&db, &mut sink) + } + drop(sink); + let messages = messages.into_inner(); + + let name = if !messages.is_empty() { + messages.join("\n") + } else { + format!( + "{}", + db.module_ir(file_id) + .llvm_module + .print_to_string() + .to_string() + ) + }; + insta::assert_snapshot!(insta::_macro_support::AutoName, name, &text); +} diff --git a/crates/mun_compiler/src/lib.rs b/crates/mun_compiler/src/lib.rs index 360460133..081604844 100644 --- a/crates/mun_compiler/src/lib.rs +++ b/crates/mun_compiler/src/lib.rs @@ -1,3 +1,5 @@ +#![allow(clippy::enum_variant_names)] // This is a HACK because we use salsa + ///! This library contains the code required to go from source code to binaries. mod diagnostic; diff --git a/crates/mun_hir/Cargo.toml b/crates/mun_hir/Cargo.toml index d0802e083..641bb7d18 100644 --- a/crates/mun_hir/Cargo.toml +++ b/crates/mun_hir/Cargo.toml @@ -12,4 +12,8 @@ rustc-hash = "1.0" once_cell = "0.2" relative-path = "0.4.0" ena = "0.13" -drop_bomb = "0.1.4" \ No newline at end of file +drop_bomb = "0.1.4" +either = "1.5.3" + +[dev-dependencies] +insta = "0.12.0" diff --git a/crates/mun_hir/src/arena/map.rs b/crates/mun_hir/src/arena/map.rs index fe406d247..02c51c3d8 100644 --- a/crates/mun_hir/src/arena/map.rs +++ b/crates/mun_hir/src/arena/map.rs @@ -13,7 +13,7 @@ pub struct ArenaMap { impl ArenaMap { pub fn insert(&mut self, id: ID, t: T) { - let idx = Self::to_idx(id); + let idx = to_idx(id); if self.v.capacity() <= idx { self.v.reserve(idx + 1 - self.v.capacity()); } @@ -26,11 +26,11 @@ impl ArenaMap { } pub fn get(&self, id: ID) -> Option<&T> { - self.v.get(Self::to_idx(id)).and_then(|it| it.as_ref()) + self.v.get(to_idx(id)).and_then(|it| it.as_ref()) } pub fn get_mut(&mut self, id: ID) -> Option<&mut T> { - self.v.get_mut(Self::to_idx(id)).and_then(|it| it.as_mut()) + self.v.get_mut(to_idx(id)).and_then(|it| it.as_mut()) } pub fn values(&self) -> impl Iterator { @@ -45,29 +45,29 @@ impl ArenaMap { self.v .iter() .enumerate() - .filter_map(|(idx, o)| Some((Self::from_idx(idx), o.as_ref()?))) + .filter_map(|(idx, o)| Some((from_idx(idx), o.as_ref()?))) } pub fn iter_mut(&mut self) -> impl Iterator { self.v .iter_mut() .enumerate() - .filter_map(|(idx, o)| Some((Self::from_idx(idx), o.as_mut()?))) + .filter_map(|(idx, o)| Some((from_idx(idx), o.as_mut()?))) } +} - fn to_idx(id: ID) -> usize { - u32::from(id.into_raw()) as usize - } +fn to_idx(id: ID) -> usize { + u32::from(id.into_raw()) as usize +} - fn from_idx(idx: usize) -> ID { - ID::from_raw((idx as u32).into()) - } +fn from_idx(idx: usize) -> ID { + ID::from_raw((idx as u32).into()) } impl std::ops::Index for ArenaMap { type Output = T; fn index(&self, id: ID) -> &T { - self.v[Self::to_idx(id)].as_ref().unwrap() + self.v[to_idx(id)].as_ref().unwrap() } } diff --git a/crates/mun_hir/src/code_model.rs b/crates/mun_hir/src/code_model.rs index 18f700c73..48f119f81 100644 --- a/crates/mun_hir/src/code_model.rs +++ b/crates/mun_hir/src/code_model.rs @@ -21,7 +21,7 @@ pub struct Module { } impl Module { - pub fn file_id(&self) -> FileId { + pub fn file_id(self) -> FileId { self.file_id } @@ -46,6 +46,7 @@ impl Module { diag.add_to(db, self, sink); } for decl in self.declarations(db) { + #[allow(clippy::single_match)] match decl { ModuleDef::Function(f) => f.diagnostics(db, sink), _ => (), @@ -181,13 +182,13 @@ impl FnData { let mut params = Vec::new(); if let Some(param_list) = src.ast.param_list() { for param in param_list.params() { - let type_ref = type_ref_builder.from_node_opt(param.ascribed_type().as_ref()); + let type_ref = type_ref_builder.alloc_from_node_opt(param.ascribed_type().as_ref()); params.push(type_ref); } } let ret_type = if let Some(type_ref) = src.ast.ret_type().and_then(|rt| rt.type_ref()) { - type_ref_builder.from_node(&type_ref) + type_ref_builder.alloc_from_node(&type_ref) } else { type_ref_builder.unit() }; diff --git a/crates/mun_hir/src/code_model/src.rs b/crates/mun_hir/src/code_model/src.rs index 30c2d38c9..c900ac1fb 100644 --- a/crates/mun_hir/src/code_model/src.rs +++ b/crates/mun_hir/src/code_model/src.rs @@ -1,8 +1,9 @@ use crate::code_model::Function; use crate::ids::AstItemDef; -use crate::{DefDatabase, FileId}; -use mun_syntax::ast; +use crate::{DefDatabase, FileId, SourceDatabase}; +use mun_syntax::{ast, AstNode, SyntaxNode}; +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct Source { pub file_id: FileId, pub ast: T, @@ -19,3 +20,16 @@ impl HasSource for Function { self.id.source(db) } } + +impl Source { + pub(crate) fn map U, U>(self, f: F) -> Source { + Source { + file_id: self.file_id, + ast: f(self.ast), + } + } + + pub(crate) fn file_syntax(&self, db: &impl SourceDatabase) -> SyntaxNode { + db.parse(self.file_id).tree().syntax().clone() + } +} diff --git a/crates/mun_hir/src/db.rs b/crates/mun_hir/src/db.rs index 79270dec7..ea74cca92 100644 --- a/crates/mun_hir/src/db.rs +++ b/crates/mun_hir/src/db.rs @@ -1,3 +1,5 @@ +#![allow(clippy::type_repetition_in_bounds)] + use crate::name_resolution::Namespace; use crate::ty::{FnSig, Ty, TypableDef}; use crate::{ diff --git a/crates/mun_hir/src/diagnostics.rs b/crates/mun_hir/src/diagnostics.rs index 8be0f03c3..dcfa781b5 100644 --- a/crates/mun_hir/src/diagnostics.rs +++ b/crates/mun_hir/src/diagnostics.rs @@ -37,8 +37,10 @@ impl dyn Diagnostic { } } +type DiagnosticCallback<'a> = Box Result<(), ()> + 'a>; + pub struct DiagnosticSink<'a> { - callbacks: Vec Result<(), ()> + 'a>>, + callbacks: Vec>, default_callback: Box, } @@ -202,6 +204,57 @@ impl Diagnostic for MismatchedType { } } +#[derive(Debug)] +pub struct IncompatibleBranch { + pub file: FileId, + pub if_expr: SyntaxNodePtr, + pub expected: Ty, + pub found: Ty, +} + +impl Diagnostic for IncompatibleBranch { + fn message(&self) -> String { + "mismatched branches".to_string() + } + + fn file(&self) -> FileId { + self.file + } + + fn syntax_node_ptr(&self) -> SyntaxNodePtr { + self.if_expr + } + + fn as_any(&self) -> &(dyn Any + Send + 'static) { + self + } +} + +#[derive(Debug)] +pub struct MissingElseBranch { + pub file: FileId, + pub if_expr: SyntaxNodePtr, + pub found: Ty, +} + +impl Diagnostic for MissingElseBranch { + fn message(&self) -> String { + "missing else branch".to_string() + } + + fn file(&self) -> FileId { + self.file + } + + fn syntax_node_ptr(&self) -> SyntaxNodePtr { + self.if_expr + } + + fn as_any(&self) -> &(dyn Any + Send + 'static) { + self + } +} + #[derive(Debug)] pub struct CannotApplyBinaryOp { pub file: FileId, diff --git a/crates/mun_hir/src/expr.rs b/crates/mun_hir/src/expr.rs index 3e193dbe5..0d1cf3291 100644 --- a/crates/mun_hir/src/expr.rs +++ b/crates/mun_hir/src/expr.rs @@ -6,12 +6,12 @@ use crate::{ }; //pub use mun_syntax::ast::PrefixOp as UnaryOp; -use crate::code_model::src::HasSource; +use crate::code_model::src::{HasSource, Source}; use crate::name::AsName; use crate::type_ref::{TypeRef, TypeRefBuilder, TypeRefId, TypeRefMap, TypeRefSourceMap}; pub use mun_syntax::ast::PrefixOp as UnaryOp; use mun_syntax::ast::{ArgListOwner, BinOp, NameOwner, TypeAscriptionOwner}; -use mun_syntax::{ast, AstNode, AstPtr, SyntaxNodePtr, T}; +use mun_syntax::{ast, AstNode, AstPtr, T}; use rustc_hash::FxHashMap; use std::ops::Index; use std::sync::Arc; @@ -102,21 +102,27 @@ impl Index for Body { } } +type ExprPtr = AstPtr; //Either, AstPtr>; +type ExprSource = Source; + +type PatPtr = AstPtr; //Either, AstPtr>; +type PatSource = Source; + /// An item body together with the mapping from syntax nodes to HIR expression Ids. This is needed /// to go from e.g. a position in a file to the HIR expression containing it; but for type /// inference etc., we want to operate on a structure that is agnostic to the action positions of /// expressions in the file, so that we don't recompute types whenever some whitespace is typed. #[derive(Default, Debug, Eq, PartialEq)] pub struct BodySourceMap { - expr_map: FxHashMap, - expr_map_back: ArenaMap, - pat_map: FxHashMap, PatId>, - pat_map_back: ArenaMap>, + expr_map: FxHashMap, + expr_map_back: ArenaMap, + pat_map: FxHashMap, + pat_map_back: ArenaMap, type_refs: TypeRefSourceMap, } impl BodySourceMap { - pub(crate) fn expr_syntax(&self, expr: ExprId) -> Option { + pub(crate) fn expr_syntax(&self, expr: ExprId) -> Option { self.expr_map_back.get(expr).cloned() } @@ -124,17 +130,15 @@ impl BodySourceMap { self.type_refs.type_ref_syntax(type_ref) } - pub(crate) fn syntax_expr(&self, ptr: SyntaxNodePtr) -> Option { + pub(crate) fn syntax_expr(&self, ptr: ExprPtr) -> Option { self.expr_map.get(&ptr).cloned() } pub(crate) fn node_expr(&self, node: &ast::Expr) -> Option { - self.expr_map - .get(&SyntaxNodePtr::new(node.syntax())) - .cloned() + self.expr_map.get(&AstPtr::new(node)).cloned() } - pub(crate) fn pat_syntax(&self, pat: PatId) -> Option> { + pub(crate) fn pat_syntax(&self, pat: PatId) -> Option { self.pat_map_back.get(pat).cloned() } @@ -176,6 +180,11 @@ pub enum Expr { args: Vec, }, Path(Path), + If { + condition: ExprId, + then_branch: ExprId, + else_branch: Option, + }, UnaryOp { expr: ExprId, op: UnaryOp, @@ -262,17 +271,28 @@ impl Expr { f(*expr); } Expr::Literal(_) => {} + Expr::If { + condition, + then_branch, + else_branch, + } => { + f(*condition); + f(*then_branch); + if let Some(else_expr) = else_branch { + f(*else_expr); + } + } } } } -// Similar to ast::PatKind +/// Similar to `ast::PatKind` #[derive(Debug, Clone, Eq, PartialEq)] pub enum Pat { - Missing, - Wild, - Path(Path), - Bind { name: Name }, + Missing, // Indicates an error + Wild, // `_` + Path(Path), // E.g. `foo::bar` + Bind { name: Name }, // E.g. `a` } impl Pat { @@ -293,13 +313,14 @@ pub(crate) struct ExprCollector { body_expr: Option, ret_type: Option, type_ref_builder: TypeRefBuilder, + current_file_id: FileId, } impl<'a, DB> ExprCollector<&'a DB> where DB: HirDatabase, { - pub fn new(owner: DefWithBody, _file_id: FileId, db: &'a DB) -> Self { + pub fn new(owner: DefWithBody, file_id: FileId, db: &'a DB) -> Self { ExprCollector { owner, db, @@ -310,23 +331,40 @@ where body_expr: None, ret_type: None, type_ref_builder: TypeRefBuilder::default(), + current_file_id: file_id, } } - fn alloc_pat(&mut self, pat: Pat, ptr: AstPtr) -> PatId { + fn alloc_pat(&mut self, pat: Pat, ptr: PatPtr) -> PatId { let id = self.pats.alloc(pat); self.source_map.pat_map.insert(ptr, id); - self.source_map.pat_map_back.insert(id, ptr); + self.source_map.pat_map_back.insert( + id, + Source { + file_id: self.current_file_id, + ast: ptr, + }, + ); id } - fn alloc_expr(&mut self, expr: Expr, syntax_ptr: SyntaxNodePtr) -> ExprId { + fn alloc_expr(&mut self, expr: Expr, ptr: ExprPtr) -> ExprId { let id = self.exprs.alloc(expr); - self.source_map.expr_map.insert(syntax_ptr, id); - self.source_map.expr_map_back.insert(id, syntax_ptr); + self.source_map.expr_map.insert(ptr, id); + self.source_map.expr_map_back.insert( + id, + Source { + file_id: self.current_file_id, + ast: ptr, + }, + ); id } + fn missing_expr(&mut self) -> ExprId { + self.exprs.alloc(Expr::Missing) + } + fn collect_fn_body(&mut self, node: &ast::FunctionDef) { if let Some(param_list) = node.param_list() { for param in param_list.params() { @@ -338,7 +376,7 @@ where let param_pat = self.collect_pat(pat); let param_type = self .type_ref_builder - .from_node_opt(param.ascribed_type().as_ref()); + .alloc_from_node_opt(param.ascribed_type().as_ref()); self.params.push((param_pat, param_type)); } } @@ -347,14 +385,14 @@ where self.body_expr = Some(body); let ret_type = if let Some(type_ref) = node.ret_type().and_then(|rt| rt.type_ref()) { - self.type_ref_builder.from_node(&type_ref) + self.type_ref_builder.alloc_from_node(&type_ref) } else { self.type_ref_builder.unit() }; self.ret_type = Some(ret_type); } - fn collect_block_opt(&mut self, block: Option) -> ExprId { + fn collect_block_opt(&mut self, block: Option) -> ExprId { if let Some(block) = block { self.collect_block(block) } else { @@ -362,7 +400,8 @@ where } } - fn collect_block(&mut self, block: ast::Block) -> ExprId { + fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId { + let syntax_node_ptr = AstPtr::new(&block.clone().into()); let statements = block .statements() .map(|s| match s.kind() { @@ -370,7 +409,7 @@ where let pat = self.collect_pat_opt(stmt.pat()); let type_ref = stmt .ascribed_type() - .map(|t| self.type_ref_builder.from_node(&t)); + .map(|t| self.type_ref_builder.alloc_from_node(&t)); let initializer = stmt.initializer().map(|e| self.collect_expr(e)); Statement::Let { pat, @@ -384,10 +423,7 @@ where }) .collect(); let tail = block.expr().map(|e| self.collect_expr(e)); - self.alloc_expr( - Expr::Block { statements, tail }, - SyntaxNodePtr::new(block.syntax()), - ) + self.alloc_expr(Expr::Block { statements, tail }, syntax_node_ptr) } fn collect_pat_opt(&mut self, pat: Option) -> PatId { @@ -407,8 +443,9 @@ where } fn collect_expr(&mut self, expr: ast::Expr) -> ExprId { - let syntax_ptr = SyntaxNodePtr::new(expr.syntax()); + let syntax_ptr = AstPtr::new(&expr.clone()); match expr.kind() { + ast::ExprKind::BlockExpr(b) => self.collect_block(b), ast::ExprKind::Literal(e) => { let lit = match e.kind() { ast::LiteralKind::Bool => Literal::Bool(e.syntax().kind() == T![true]), @@ -536,6 +573,34 @@ where .unwrap_or(Expr::Missing); self.alloc_expr(path, syntax_ptr) } + ast::ExprKind::IfExpr(e) => { + let then_branch = self.collect_block_opt(e.then_branch()); + + let else_branch = e.else_branch().map(|b| match b { + ast::ElseBranch::Block(it) => self.collect_block(it), + ast::ElseBranch::IfExpr(elif) => { + let expr = ast::Expr::cast(elif.syntax().clone()).unwrap(); + self.collect_expr(expr) + } + }); + + let condition = match e.condition() { + None => self.missing_expr(), + Some(condition) => match condition.pat() { + None => self.collect_expr_opt(condition.expr()), + _ => unreachable!("patterns in conditions are not yet supported"), + }, + }; + + self.alloc_expr( + Expr::If { + condition, + then_branch, + else_branch, + }, + syntax_ptr, + ) + } ast::ExprKind::ParenExpr(e) => { let inner = self.collect_expr_opt(e.expr()); // make the paren expr point to the inner expression as well diff --git a/crates/mun_hir/src/lib.rs b/crates/mun_hir/src/lib.rs index d31755a0b..fd794d9eb 100644 --- a/crates/mun_hir/src/lib.rs +++ b/crates/mun_hir/src/lib.rs @@ -22,6 +22,9 @@ mod source_id; mod ty; mod type_ref; +#[cfg(test)] +mod mock; + pub use salsa; pub use crate::{ diff --git a/crates/mun_hir/src/mock.rs b/crates/mun_hir/src/mock.rs new file mode 100644 index 000000000..db86f9ced --- /dev/null +++ b/crates/mun_hir/src/mock.rs @@ -0,0 +1,34 @@ +use crate::db::SourceDatabase; +use crate::{FileId, PackageInput, RelativePathBuf}; +use std::sync::Arc; + +/// A mock implementation of the IR database. It can be used to set up a simple test case. +#[salsa::database( + crate::SourceDatabaseStorage, + crate::DefDatabaseStorage, + crate::HirDatabaseStorage +)] +#[derive(Default, Debug)] +pub(crate) struct MockDatabase { + runtime: salsa::Runtime, +} + +impl salsa::Database for MockDatabase { + fn salsa_runtime(&self) -> &salsa::Runtime { + &self.runtime + } +} + +impl MockDatabase { + /// Creates a database from the given text. + pub fn with_single_file(text: &str) -> (MockDatabase, FileId) { + let mut db: MockDatabase = Default::default(); + let file_id = FileId(0); + db.set_file_relative_path(file_id, RelativePathBuf::from("main.mun")); + db.set_file_text(file_id, Arc::new(text.to_string())); + let mut package_input = PackageInput::default(); + package_input.add_module(file_id); + db.set_package_input(Arc::new(package_input)); + (db, file_id) + } +} diff --git a/crates/mun_hir/src/name_resolution.rs b/crates/mun_hir/src/name_resolution.rs index c8b12a1c1..e5fa777bb 100644 --- a/crates/mun_hir/src/name_resolution.rs +++ b/crates/mun_hir/src/name_resolution.rs @@ -47,6 +47,7 @@ pub(crate) fn module_scope_query(db: &impl HirDatabase, file_id: FileId) -> Arc< let mut scope = ModuleScope::default(); let defs = db.module_data(file_id); for def in defs.definitions() { + #[allow(clippy::single_match)] match def { ModuleDef::Function(f) => { scope.items.insert( diff --git a/crates/mun_hir/src/path.rs b/crates/mun_hir/src/path.rs index 0592f4cfd..9f9b468c1 100644 --- a/crates/mun_hir/src/path.rs +++ b/crates/mun_hir/src/path.rs @@ -25,31 +25,31 @@ impl Path { pub fn from_ast(path: ast::Path) -> Option { let mut kind = PathKind::Plain; let mut segments = Vec::new(); - loop { - let segment = path.segment()?; + // loop { + let segment = path.segment()?; - if segment.has_colon_colon() { - kind = PathKind::Abs; - } + if segment.has_colon_colon() { + kind = PathKind::Abs; + } - match segment.kind()? { - ast::PathSegmentKind::Name(name) => { - let segment = PathSegment { - name: name.as_name(), - }; - segments.push(segment); - } - ast::PathSegmentKind::SelfKw => { - kind = PathKind::Self_; - break; - } - ast::PathSegmentKind::SuperKw => { - kind = PathKind::Super; - break; - } + match segment.kind()? { + ast::PathSegmentKind::Name(name) => { + let segment = PathSegment { + name: name.as_name(), + }; + segments.push(segment); + } + ast::PathSegmentKind::SelfKw => { + kind = PathKind::Self_; + // break; + } + ast::PathSegmentKind::SuperKw => { + kind = PathKind::Super; + // break; } - break; } + // break; + // } segments.reverse(); Some(Path { kind, segments }) } diff --git a/crates/mun_hir/src/source_id.rs b/crates/mun_hir/src/source_id.rs index bd1d214d1..e9bdb1614 100644 --- a/crates/mun_hir/src/source_id.rs +++ b/crates/mun_hir/src/source_id.rs @@ -33,11 +33,11 @@ impl Hash for AstId { } impl AstId { - pub(crate) fn file_id(&self) -> FileId { + pub(crate) fn file_id(self) -> FileId { self.file_id } - pub(crate) fn to_node(&self, db: &impl DefDatabase) -> N { + pub(crate) fn to_node(self, db: &impl DefDatabase) -> N { let syntax_node = db.ast_id_to_node(self.file_id, self.file_ast_id.raw); N::cast(syntax_node).unwrap() } diff --git a/crates/mun_hir/src/ty.rs b/crates/mun_hir/src/ty.rs index 221ef0264..7288abb3d 100644 --- a/crates/mun_hir/src/ty.rs +++ b/crates/mun_hir/src/ty.rs @@ -12,6 +12,9 @@ use std::sync::Arc; mod op; +#[cfg(test)] +mod tests; + /// This should be cheap to clone. #[derive(Clone, PartialEq, Eq, Debug, Hash)] pub enum Ty { @@ -52,6 +55,9 @@ pub enum TypeCtor { /// The primitive boolean type. Written as `bool`. Bool, + /// The never type `never`. + Never, + /// The anonymous type of a function declaration/definition. Each /// function has a unique type, which is output (for a function /// named `foo` returning an `number`) as `fn() -> number {foo}`. @@ -86,6 +92,13 @@ impl Ty { *self == Ty::Empty } + pub fn is_never(&self) -> bool { + match self.as_simple() { + Some(TypeCtor::Never) => true, + _ => false, + } + } + /// Returns the function definition for the given expression or `None` if the type does not /// represent a function. pub fn as_function_def(&self) -> Option { @@ -150,31 +163,30 @@ impl FnSig { impl HirDisplay for Ty { fn hir_fmt(&self, f: &mut HirFormatter) -> fmt::Result { match self { - Ty::Apply(a_ty) => a_ty.hir_fmt(f)?, - Ty::Unknown => write!(f, "{{unknown}}")?, - Ty::Empty => write!(f, "nothing")?, - Ty::Infer(tv) => write!(f, "'{}", tv.0)?, + Ty::Apply(a_ty) => a_ty.hir_fmt(f), + Ty::Unknown => write!(f, "{{unknown}}"), + Ty::Empty => write!(f, "nothing"), + Ty::Infer(tv) => write!(f, "'{}", tv.0), } - Ok(()) } } impl HirDisplay for ApplicationTy { fn hir_fmt(&self, f: &mut HirFormatter) -> fmt::Result { match self.ctor { - TypeCtor::Float => write!(f, "float")?, - TypeCtor::Int => write!(f, "int")?, - TypeCtor::Bool => write!(f, "bool")?, + TypeCtor::Float => write!(f, "float"), + TypeCtor::Int => write!(f, "int"), + TypeCtor::Bool => write!(f, "bool"), + TypeCtor::Never => write!(f, "never"), TypeCtor::FnDef(def) => { let sig = f.db.fn_signature(def); let name = def.name(f.db); write!(f, "function {}", name)?; write!(f, "(")?; f.write_joined(sig.params(), ", ")?; - write!(f, ") -> {}", sig.ret().display(f.db))?; + write!(f, ") -> {}", sig.ret().display(f.db)) } } - Ok(()) } } diff --git a/crates/mun_hir/src/ty/infer.rs b/crates/mun_hir/src/ty/infer.rs index 7b8a141b2..d758bad58 100644 --- a/crates/mun_hir/src/ty/infer.rs +++ b/crates/mun_hir/src/ty/infer.rs @@ -22,12 +22,29 @@ mod type_variable; pub use type_variable::TypeVarId; +macro_rules! ty_app { + ($ctor:pat, $param:pat) => { + $crate::ty::Ty::Apply($crate::ty::ApplicationTy { + ctor: $ctor, + parameters: $param, + }) + }; + ($ctor:pat) => { + $crate::ty::Ty::Apply($crate::ty::ApplicationTy { + ctor: $ctor, + .. + }) + }; +} + +mod coerce; + /// The result of type inference: A mapping from expressions and patterns to types. #[derive(Clone, PartialEq, Eq, Debug)] pub struct InferenceResult { - type_of_expr: ArenaMap, - type_of_pat: ArenaMap, - diagnostics: Vec, + pub(crate) type_of_expr: ArenaMap, + pub(crate) type_of_pat: ArenaMap, + pub(crate) diagnostics: Vec, } impl Index for InferenceResult { @@ -117,7 +134,7 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { /// Given a `TypeRefId`, resolve the reference to an actual `Ty`. If the the type could not /// be resolved an error is emitted and `Ty::Error` is returned. - fn resolve_type(&mut self, type_ref: &TypeRefId) -> Ty { + fn resolve_type(&mut self, type_ref: TypeRefId) -> Ty { // Try to resolve the type from the Hir let result = Ty::from_hir( self.db, @@ -141,6 +158,25 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { } } +impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { + /// Unify the specified types, returns true if successful; false otherwise. + fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool { + if ty1 == ty2 { + return true; + } + + self.unify_inner_trivial(&ty1, &ty2) + } + + /// This function performs trivial unifications. Returns true if a unification took place; + fn unify_inner_trivial(&mut self, ty1: &Ty, ty2: &Ty) -> bool { + match (ty1, ty2) { + (Ty::Unknown, _) | (_, Ty::Unknown) => true, + _ => false, + } + } +} + impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { /// Collect all the parameter patterns from the body. After calling this method the `return_ty` /// will have a valid value, also all parameters are added inferred. @@ -150,17 +186,18 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { // Iterate over all the parameters and associated types of the body and infer the types of // the parameters. for (pat, type_ref) in body.params().iter() { - let ty = self.resolve_type(type_ref); + let ty = self.resolve_type(*type_ref); self.infer_pat(*pat, ty); } // Resolve the return type - self.return_ty = self.resolve_type(&body.ret_type()) + self.return_ty = self.resolve_type(body.ret_type()) } /// Record the type of the specified pattern and all sub-patterns. fn infer_pat(&mut self, pat: PatId, ty: Ty) { let body = Arc::clone(&self.body); // avoid borrow checker problem + #[allow(clippy::single_match)] match &body[pat] { Pat::Bind { .. } => { self.set_pat_type(pat, ty); @@ -188,6 +225,44 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { self.infer_path_expr(&resolver, p, tgt_expr.into()) .unwrap_or(Ty::Unknown) } + Expr::If { + condition, + then_branch, + else_branch, + } => { + self.infer_expr( + *condition, + &Expectation::has_type(Ty::simple(TypeCtor::Bool)), + ); + let then_ty = self.infer_expr(*then_branch, &expected); + match else_branch { + Some(else_branch) => { + let else_ty = self.infer_expr(*else_branch, &expected); + match self.coerce_merge_branch(&then_ty, &else_ty) { + Some(ty) => ty, + None => { + self.diagnostics + .push(InferenceDiagnostic::IncompatibleBranches { + id: tgt_expr, + then_ty: then_ty.clone(), + else_ty: else_ty.clone(), + }); + then_ty + } + } + } + None => { + if !self.coerce(&then_ty, &Ty::Empty) { + self.diagnostics + .push(InferenceDiagnostic::MissingElseBranch { + id: tgt_expr, + then_ty: then_ty.clone(), + }) + } + Ty::Empty + } + } + } Expr::BinaryOp { lhs, rhs, op } => match op { Some(op) => { let lhs_ty = self.infer_expr(*lhs, &Expectation::none()); @@ -198,10 +273,10 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { _ => Ty::Unknown, }, Expr::Block { statements, tail } => self.infer_block(statements, *tail, expected), - Expr::Call { callee: call, args } => self.infer_call(&tgt_expr, call, args, expected), + Expr::Call { callee: call, args } => self.infer_call(tgt_expr, *call, args, expected), Expr::Literal(lit) => match lit { Literal::String(_) => Ty::Unknown, - Literal::Bool(_) => Ty::Unknown, + Literal::Bool(_) => Ty::simple(TypeCtor::Bool), Literal::Int(_) => Ty::simple(TypeCtor::Int), Literal::Float(_) => Ty::simple(TypeCtor::Float), }, @@ -226,18 +301,18 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { /// Inferences the type of a call expression. fn infer_call( &mut self, - tgt_expr: &ExprId, - callee: &ExprId, - args: &Vec, + tgt_expr: ExprId, + callee: ExprId, + args: &[ExprId], _expected: &Expectation, ) -> Ty { - let callee_ty = self.infer_expr(*callee, &Expectation::none()); + let callee_ty = self.infer_expr(callee, &Expectation::none()); let (param_tys, ret_ty) = match callee_ty.callable_sig(self.db) { Some(sig) => (sig.params().to_vec(), sig.ret().clone()), None => { self.diagnostics .push(InferenceDiagnostic::ExpectedFunction { - id: *callee, + id: callee, found: callee_ty, }); (Vec::new(), Ty::Unknown) @@ -248,11 +323,11 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { } /// Checks whether the specified passed arguments match the parameters of a callable definition. - fn check_call_arguments(&mut self, tgt_expr: &ExprId, args: &[ExprId], param_tys: &[Ty]) { + fn check_call_arguments(&mut self, tgt_expr: ExprId, args: &[ExprId], param_tys: &[Ty]) { if args.len() != param_tys.len() { self.diagnostics .push(InferenceDiagnostic::ParameterCountMismatch { - id: *tgt_expr, + id: tgt_expr, found: args.len(), expected: param_tys.len(), }) @@ -335,7 +410,7 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { } => { let decl_ty = type_ref .as_ref() - .map(|tr| self.resolve_type(tr)) + .map(|tr| self.resolve_type(*tr)) .unwrap_or(Ty::Unknown); //let decl_ty = self.insert_type_vars(decl_ty); let ty = if let Some(expr) = initializer { @@ -392,10 +467,14 @@ impl Expectation { fn none() -> Self { Expectation { ty: Ty::Unknown } } + + fn is_none(&self) -> bool { + self.ty == Ty::Unknown + } } #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub(super) enum ExprOrPatId { +pub(crate) enum ExprOrPatId { ExprId(ExprId), PatId(PatId), } @@ -414,7 +493,8 @@ impl From for ExprOrPatId { mod diagnostics { use crate::diagnostics::{ - CannotApplyBinaryOp, ExpectedFunction, MismatchedType, ParameterCountMismatch, + CannotApplyBinaryOp, ExpectedFunction, IncompatibleBranch, MismatchedType, + MissingElseBranch, ParameterCountMismatch, }; use crate::{ code_model::src::HasSource, @@ -425,7 +505,7 @@ mod diagnostics { }; #[derive(Debug, PartialEq, Eq, Clone)] - pub(super) enum InferenceDiagnostic { + pub(crate) enum InferenceDiagnostic { UnresolvedValue { id: ExprOrPatId, }, @@ -446,6 +526,15 @@ mod diagnostics { expected: Ty, found: Ty, }, + IncompatibleBranches { + id: ExprId, + then_ty: Ty, + else_ty: Ty, + }, + MissingElseBranch { + id: ExprId, + then_ty: Ty, + }, CannotApplyBinaryOp { id: ExprId, lhs: Ty, @@ -454,7 +543,7 @@ mod diagnostics { } impl InferenceDiagnostic { - pub(super) fn add_to( + pub(crate) fn add_to( &self, db: &impl HirDatabase, owner: Function, @@ -465,9 +554,11 @@ mod diagnostics { let file = owner.source(db).file_id; let body = owner.body_source_map(db); let expr = match id { - ExprOrPatId::ExprId(id) => body.expr_syntax(*id), + ExprOrPatId::ExprId(id) => { + body.expr_syntax(*id).map(|ptr| ptr.ast.syntax_node_ptr()) + } ExprOrPatId::PatId(id) => { - body.pat_syntax(*id).map(|ptr| ptr.syntax_node_ptr()) + body.pat_syntax(*id).map(|ptr| ptr.ast.syntax_node_ptr()) } } .unwrap(); @@ -487,7 +578,7 @@ mod diagnostics { } => { let file = owner.source(db).file_id; let body = owner.body_source_map(db); - let expr = body.expr_syntax(*id).unwrap(); + let expr = body.expr_syntax(*id).unwrap().ast.syntax_node_ptr(); sink.push(ParameterCountMismatch { file, expr, @@ -498,7 +589,7 @@ mod diagnostics { InferenceDiagnostic::ExpectedFunction { id, found } => { let file = owner.source(db).file_id; let body = owner.body_source_map(db); - let expr = body.expr_syntax(*id).unwrap(); + let expr = body.expr_syntax(*id).unwrap().ast.syntax_node_ptr(); sink.push(ExpectedFunction { file, expr, @@ -512,7 +603,7 @@ mod diagnostics { } => { let file = owner.source(db).file_id; let body = owner.body_source_map(db); - let expr = body.expr_syntax(*id).unwrap(); + let expr = body.expr_syntax(*id).unwrap().ast.syntax_node_ptr(); sink.push(MismatchedType { file, expr, @@ -520,10 +611,35 @@ mod diagnostics { expected: expected.clone(), }); } + InferenceDiagnostic::IncompatibleBranches { + id, + then_ty, + else_ty, + } => { + let file = owner.source(db).file_id; + let body = owner.body_source_map(db); + let expr = body.expr_syntax(*id).unwrap().ast.syntax_node_ptr(); + sink.push(IncompatibleBranch { + file, + if_expr: expr, + expected: then_ty.clone(), + found: else_ty.clone(), + }); + } + InferenceDiagnostic::MissingElseBranch { id, then_ty } => { + let file = owner.source(db).file_id; + let body = owner.body_source_map(db); + let expr = body.expr_syntax(*id).unwrap().ast.syntax_node_ptr(); + sink.push(MissingElseBranch { + file, + if_expr: expr, + found: then_ty.clone(), + }); + } InferenceDiagnostic::CannotApplyBinaryOp { id, lhs, rhs } => { let file = owner.source(db).file_id; let body = owner.body_source_map(db); - let expr = body.expr_syntax(*id).unwrap(); + let expr = body.expr_syntax(*id).unwrap().ast.syntax_node_ptr(); sink.push(CannotApplyBinaryOp { file, expr, diff --git a/crates/mun_hir/src/ty/infer/coerce.rs b/crates/mun_hir/src/ty/infer/coerce.rs new file mode 100644 index 000000000..c2d715516 --- /dev/null +++ b/crates/mun_hir/src/ty/infer/coerce.rs @@ -0,0 +1,34 @@ +use super::InferenceResultBuilder; +use crate::{HirDatabase, Ty, TypeCtor}; + +impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { + /// Unify two types, but may coerce the first one to the second using implicit coercion rules if + /// needed. + pub(super) fn coerce(&mut self, from_ty: &Ty, to_ty: &Ty) -> bool { + self.coerce_inner(from_ty.clone(), &to_ty) + } + + /// Merge two types from different branches, with possible implicit coerce. + pub(super) fn coerce_merge_branch(&mut self, ty1: &Ty, ty2: &Ty) -> Option { + if self.coerce(ty1, ty2) { + Some(ty2.clone()) + } else if self.coerce(ty2, ty1) { + Some(ty1.clone()) + } else { + None + } + } + + fn coerce_inner(&mut self, from_ty: Ty, to_ty: &Ty) -> bool { + match (&from_ty, to_ty) { + (ty_app!(TypeCtor::Never), ..) => return true, + _ => { + if self.unify_inner_trivial(&from_ty, &to_ty) { + return true; + } + } + }; + + self.unify(&from_ty, to_ty) + } +} diff --git a/crates/mun_hir/src/ty/lower.rs b/crates/mun_hir/src/ty/lower.rs index 2e11ac394..085dafa39 100644 --- a/crates/mun_hir/src/ty/lower.rs +++ b/crates/mun_hir/src/ty/lower.rs @@ -17,7 +17,7 @@ impl Ty { db: &impl HirDatabase, resolver: &Resolver, type_ref_map: &TypeRefMap, - type_ref: &TypeRefId, + type_ref: TypeRefId, ) -> LowerResult { let mut diagnostics = Vec::new(); let ty = @@ -30,19 +30,18 @@ impl Ty { resolver: &Resolver, type_ref_map: &TypeRefMap, diagnostics: &mut Vec, - type_ref: &TypeRefId, + type_ref: TypeRefId, ) -> Ty { - let res = match &type_ref_map[*type_ref] { + let res = match &type_ref_map[type_ref] { TypeRef::Path(path) => Ty::from_hir_path(db, resolver, path), TypeRef::Error => Some(Ty::Unknown), TypeRef::Empty => Some(Ty::Empty), + TypeRef::Never => Some(Ty::simple(TypeCtor::Never)), }; if let Some(ty) = res { ty } else { - diagnostics.push(LowerDiagnostic::UnresolvedType { - id: type_ref.clone(), - }); + diagnostics.push(LowerDiagnostic::UnresolvedType { id: type_ref }); Ty::Unknown } } @@ -138,9 +137,9 @@ pub fn fn_sig_for_fn(db: &impl HirDatabase, def: Function) -> FnSig { let params = data .params() .iter() - .map(|tr| Ty::from_hir(db, &resolver, data.type_ref_map(), tr).ty) + .map(|tr| Ty::from_hir(db, &resolver, data.type_ref_map(), *tr).ty) .collect::>(); - let ret = Ty::from_hir(db, &resolver, data.type_ref_map(), data.ret_type()).ty; + let ret = Ty::from_hir(db, &resolver, data.type_ref_map(), *data.ret_type()).ty; FnSig::from_params_and_return(params, ret) } diff --git a/crates/mun_hir/src/ty/snapshots/tests__infer_basics.snap b/crates/mun_hir/src/ty/snapshots/tests__infer_basics.snap new file mode 100644 index 000000000..71127e602 --- /dev/null +++ b/crates/mun_hir/src/ty/snapshots/tests__infer_basics.snap @@ -0,0 +1,13 @@ +--- +source: crates/mun_hir/src/ty/tests.rs +expression: "fn test(a:int, b:float, c:never, d:bool): bool {\n a;\n b;\n c;\n d\n}" +--- +[8; 9) 'a': int +[15; 16) 'b': float +[24; 25) 'c': never +[33; 34) 'd': bool +[47; 77) '{ ... d }': bool +[53; 54) 'a': int +[60; 61) 'b': float +[67; 68) 'c': never +[74; 75) 'd': bool diff --git a/crates/mun_hir/src/ty/snapshots/tests__infer_branching.snap b/crates/mun_hir/src/ty/snapshots/tests__infer_branching.snap new file mode 100644 index 000000000..ae137c554 --- /dev/null +++ b/crates/mun_hir/src/ty/snapshots/tests__infer_branching.snap @@ -0,0 +1,42 @@ +--- +source: crates/mun_hir/src/ty/tests.rs +expression: "fn test() {\n let a = if true { 3 } else { 4 }\n let b = if true { 3 } // Missing else branch\n let c = if true { 3; }\n let d = if true { 5 } else if false { 3 } else { 4 }\n let e = if true { 5.0 } else { 5 } // Mismatched branches\n}" +--- +[61; 74): missing else branch +[208; 234): mismatched branches +[10; 260) '{ ...ches }': nothing +[20; 21) 'a': int +[24; 48) 'if tru... { 4 }': int +[27; 31) 'true': bool +[32; 37) '{ 3 }': int +[34; 35) '3': int +[43; 48) '{ 4 }': int +[45; 46) '4': int +[57; 58) 'b': nothing +[61; 74) 'if true { 3 }': nothing +[64; 68) 'true': bool +[69; 74) '{ 3 }': int +[71; 72) '3': int +[120; 121) 'c': nothing +[124; 138) 'if true { 3; }': nothing +[127; 131) 'true': bool +[132; 138) '{ 3; }': nothing +[134; 135) '3': int +[147; 148) 'd': int +[151; 195) 'if tru... { 4 }': int +[154; 158) 'true': bool +[159; 164) '{ 5 }': int +[161; 162) '5': int +[170; 195) 'if fal... { 4 }': int +[173; 178) 'false': bool +[179; 184) '{ 3 }': int +[181; 182) '3': int +[190; 195) '{ 4 }': int +[192; 193) '4': int +[204; 205) 'e': float +[208; 234) 'if tru... { 5 }': float +[211; 215) 'true': bool +[216; 223) '{ 5.0 }': float +[218; 221) '5.0': float +[229; 234) '{ 5 }': int +[231; 232) '5': int diff --git a/crates/mun_hir/src/ty/snapshots/tests__void_return.snap b/crates/mun_hir/src/ty/snapshots/tests__void_return.snap new file mode 100644 index 000000000..c148fc7de --- /dev/null +++ b/crates/mun_hir/src/ty/snapshots/tests__void_return.snap @@ -0,0 +1,12 @@ +--- +source: crates/mun_hir/src/ty/tests.rs +expression: "fn bar() {\n let a = 3;\n}\nfn foo(a:int) {\n let c = bar()\n}" +--- +[9; 27) '{ ...= 3; }': nothing +[19; 20) 'a': int +[23; 24) '3': int +[35; 36) 'a': int +[42; 63) '{ ...ar() }': nothing +[52; 53) 'c': nothing +[56; 59) 'bar': function bar() -> nothing +[56; 61) 'bar()': nothing diff --git a/crates/mun_hir/src/ty/tests.rs b/crates/mun_hir/src/ty/tests.rs new file mode 100644 index 000000000..65605a044 --- /dev/null +++ b/crates/mun_hir/src/ty/tests.rs @@ -0,0 +1,151 @@ +use crate::db::SourceDatabase; +use crate::diagnostics::DiagnosticSink; +use crate::expr::BodySourceMap; +use crate::ids::LocationCtx; +use crate::mock::MockDatabase; +use crate::{Function, HirDisplay, InferenceResult}; +use mun_syntax::{ast, AstNode}; +use std::fmt::Write; +use std::sync::Arc; + +#[test] +fn infer_basics() { + infer_snapshot( + r#" + fn test(a:int, b:float, c:never, d:bool): bool { + a; + b; + c; + d + } + "#, + ) +} + +#[test] +fn infer_branching() { + infer_snapshot( + r#" + fn test() { + let a = if true { 3 } else { 4 } + let b = if true { 3 } // Missing else branch + let c = if true { 3; } + let d = if true { 5 } else if false { 3 } else { 4 } + let e = if true { 5.0 } else { 5 } // Mismatched branches + } + "#, + ) +} + +#[test] +fn void_return() { + infer_snapshot( + r#" + fn bar() { + let a = 3; + } + fn foo(a:int) { + let c = bar() + } + "#, + ) +} + +fn infer_snapshot(text: &str) { + let text = text.trim().replace("\n ", "\n"); + insta::assert_snapshot!(insta::_macro_support::AutoName, infer(&text), &text); +} + +fn infer(content: &str) -> String { + let (db, file_id) = MockDatabase::with_single_file(content); + let source_file = db.parse(file_id).ok().unwrap(); + + let mut acc = String::new(); + + let mut infer_def = |infer_result: Arc, + body_source_map: Arc| { + let mut types = Vec::new(); + + for (pat, ty) in infer_result.type_of_pat.iter() { + let syntax_ptr = match body_source_map.pat_syntax(pat) { + Some(sp) => sp.map(|ast| ast.syntax_node_ptr()), + None => continue, + }; + types.push((syntax_ptr, ty)); + } + + for (expr, ty) in infer_result.type_of_expr.iter() { + let syntax_ptr = match body_source_map.expr_syntax(expr) { + Some(sp) => sp.map(|ast| ast.syntax_node_ptr()), + None => continue, + }; + types.push((syntax_ptr, ty)); + } + + // Sort ranges for consistency + types.sort_by_key(|(src_ptr, _)| (src_ptr.ast.range().start(), src_ptr.ast.range().end())); + for (src_ptr, ty) in &types { + let node = src_ptr.ast.to_node(&src_ptr.file_syntax(&db)); + + let (range, text) = ( + src_ptr.ast.range(), + node.text().to_string().replace("\n", " "), + ); + write!( + acc, + "{} '{}': {}\n", + range, + ellipsize(text, 15), + ty.display(&db) + ) + .unwrap(); + } + }; + + let mut diags = String::new(); + + let mut diag_sink = DiagnosticSink::new(|diag| { + write!(diags, "{}: {}\n", diag.highlight_range(), diag.message()).unwrap(); + }); + + let ctx = LocationCtx::new(&db, file_id); + for node in source_file.syntax().descendants() { + if let Some(def) = ast::FunctionDef::cast(node.clone()) { + let fun = Function { + id: ctx.to_def(&def), + }; + let source_map = fun.body_source_map(&db); + let infer_result = fun.infer(&db); + + for diag in infer_result.diagnostics.iter() { + diag.add_to(&db, fun, &mut diag_sink); + } + + infer_def(infer_result, source_map); + } + } + + drop(diag_sink); + + acc.truncate(acc.trim_end().len()); + diags.truncate(diags.trim_end().len()); + [diags, acc].join("\n").trim().to_string() +} + +fn ellipsize(mut text: String, max_len: usize) -> String { + if text.len() <= max_len { + return text; + } + let ellipsis = "..."; + let e_len = ellipsis.len(); + let mut prefix_len = (max_len - e_len) / 2; + while !text.is_char_boundary(prefix_len) { + prefix_len += 1; + } + let mut suffix_len = max_len - e_len - prefix_len; + while !text.is_char_boundary(text.len() - suffix_len) { + suffix_len += 1; + } + text.replace_range(prefix_len..text.len() - suffix_len, ellipsis); + text +} diff --git a/crates/mun_hir/src/type_ref.rs b/crates/mun_hir/src/type_ref.rs index 7d742363b..16a30b53d 100644 --- a/crates/mun_hir/src/type_ref.rs +++ b/crates/mun_hir/src/type_ref.rs @@ -16,6 +16,7 @@ impl_arena_id!(TypeRefId); #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum TypeRef { Path(Path), + Never, Empty, Error, } @@ -63,15 +64,15 @@ impl TypeRefBuilder { id } - pub fn from_node_opt(&mut self, node: Option<&ast::TypeRef>) -> TypeRefId { + pub fn alloc_from_node_opt(&mut self, node: Option<&ast::TypeRef>) -> TypeRefId { if let Some(node) = node { - self.from_node(node) + self.alloc_from_node(node) } else { self.error() } } - pub fn from_node(&mut self, node: &ast::TypeRef) -> TypeRefId { + pub fn alloc_from_node(&mut self, node: &ast::TypeRef) -> TypeRefId { use mun_syntax::ast::TypeRefKind::*; let ptr = AstPtr::new(node); let type_ref = match node.kind() { @@ -80,6 +81,7 @@ impl TypeRefBuilder { .and_then(Path::from_ast) .map(TypeRef::Path) .unwrap_or(TypeRef::Error), + NeverType(_) => TypeRef::Never, }; self.alloc_type_ref(type_ref, ptr) } diff --git a/crates/mun_runtime/src/macros.rs b/crates/mun_runtime/src/macros.rs index 5cc79a88f..c3ea35b86 100644 --- a/crates/mun_runtime/src/macros.rs +++ b/crates/mun_runtime/src/macros.rs @@ -34,6 +34,7 @@ macro_rules! invoke_fn_impl { impl<'r, 's, $($T: Reflection,)* Output: Reflection> $ErrName<'r, 's, $($T,)* Output> { /// Constructs a new invocation error. + #[allow(clippy::too_many_arguments)] pub fn new(err_msg: String, runtime: &'r mut MunRuntime, function_name: &'s str, $($Arg: $T),*) -> Self { Self { msg: err_msg, @@ -77,6 +78,7 @@ macro_rules! invoke_fn_impl { /// /// If an error occurs when invoking the method, an error message is logged. The /// runtime continues looping until the cause of the error has been resolved. + #[allow(clippy::too_many_arguments)] pub fn $FnName<'r, 's, $($T: Reflection,)* Output: Reflection>( runtime: &'r mut MunRuntime, function_name: &'s str, diff --git a/crates/mun_runtime/src/test.rs b/crates/mun_runtime/src/test.rs index 0bfa98498..6388d8a9f 100644 --- a/crates/mun_runtime/src/test.rs +++ b/crates/mun_runtime/src/test.rs @@ -207,3 +207,31 @@ fn booleans() { false ); } + +#[test] +fn fibonacci() { + let compile_result = compile( + r#" + fn fibonacci(n:int):int { + if n <= 1 { + n + } else { + fibonacci(n-1) + fibonacci(n-2) + } + } + "#, + ); + let mut runtime = compile_result.new_runtime(); + assert_eq!( + MunRuntime::invoke_fn1::(&mut runtime, "fibonacci", 5).unwrap(), + 5 + ); + assert_eq!( + MunRuntime::invoke_fn1::(&mut runtime, "fibonacci", 11).unwrap(), + 89 + ); + assert_eq!( + MunRuntime::invoke_fn1::(&mut runtime, "fibonacci", 16).unwrap(), + 987 + ); +} diff --git a/crates/mun_runtime/tests/data/main.dll b/crates/mun_runtime/tests/data/main.dll deleted file mode 100644 index 12dd3df75..000000000 Binary files a/crates/mun_runtime/tests/data/main.dll and /dev/null differ diff --git a/crates/mun_runtime_capi/src/lib.rs b/crates/mun_runtime_capi/src/lib.rs index 5c4b06dab..8c0a9864e 100644 --- a/crates/mun_runtime_capi/src/lib.rs +++ b/crates/mun_runtime_capi/src/lib.rs @@ -7,18 +7,21 @@ use std::os::raw::c_char; pub struct RuntimeHandle(*mut c_void); #[no_mangle] -pub extern "C" fn create_runtime(library_path: *const c_char, handle: *mut RuntimeHandle) -> u64 /* error */ +pub unsafe extern "C" fn create_runtime( + library_path: *const c_char, + handle: *mut RuntimeHandle, +) -> u64 /* error */ { if library_path.is_null() { return 1; } - let library_path = match unsafe { CStr::from_ptr(library_path) }.to_str() { + let library_path = match CStr::from_ptr(library_path).to_str() { Ok(path) => path, Err(_) => return 2, }; - let handle = match unsafe { handle.as_mut() } { + let handle = match handle.as_mut() { Some(handle) => handle, None => return 3, }; @@ -40,28 +43,28 @@ pub extern "C" fn destroy_runtime(handle: RuntimeHandle) { } #[no_mangle] -pub extern "C" fn runtime_get_function_info( +pub unsafe extern "C" fn runtime_get_function_info( handle: RuntimeHandle, fn_name: *const c_char, has_fn_info: *mut bool, fn_info: *mut FunctionInfo, ) -> u64 /* error */ { - let runtime = match unsafe { (handle.0 as *mut MunRuntime).as_ref() } { + let runtime = match (handle.0 as *mut MunRuntime).as_ref() { Some(runtime) => runtime, None => return 1, }; - let fn_name = match unsafe { CStr::from_ptr(fn_name) }.to_str() { + let fn_name = match CStr::from_ptr(fn_name).to_str() { Ok(name) => name, Err(_) => return 2, }; - let has_fn_info = match unsafe { has_fn_info.as_mut() } { + let has_fn_info = match has_fn_info.as_mut() { Some(has_info) => has_info, None => return 3, }; - let fn_info = match unsafe { fn_info.as_mut() } { + let fn_info = match fn_info.as_mut() { Some(info) => info, None => return 4, }; @@ -78,13 +81,14 @@ pub extern "C" fn runtime_get_function_info( } #[no_mangle] -pub extern "C" fn runtime_update(handle: RuntimeHandle, updated: *mut bool) -> u64 /* error */ { - let runtime = match unsafe { (handle.0 as *mut MunRuntime).as_mut() } { +pub unsafe extern "C" fn runtime_update(handle: RuntimeHandle, updated: *mut bool) -> u64 /* error */ +{ + let runtime = match (handle.0 as *mut MunRuntime).as_mut() { Some(runtime) => runtime, None => return 1, }; - let updated = match unsafe { updated.as_mut() } { + let updated = match updated.as_mut() { Some(updated) => updated, None => return 2, }; diff --git a/crates/mun_syntax/src/ast/expr_extensions.rs b/crates/mun_syntax/src/ast/expr_extensions.rs index 3fffe0ed1..c4970c296 100644 --- a/crates/mun_syntax/src/ast/expr_extensions.rs +++ b/crates/mun_syntax/src/ast/expr_extensions.rs @@ -1,5 +1,5 @@ use super::{children, BinExpr}; -use crate::ast::Literal; +use crate::ast::{child_opt, AstChildren, Literal}; use crate::{ ast, AstNode, SyntaxKind::{self, *}, @@ -132,3 +132,29 @@ impl Literal { } } } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ElseBranch { + Block(ast::BlockExpr), + IfExpr(ast::IfExpr), +} + +impl ast::IfExpr { + pub fn then_branch(&self) -> Option { + self.blocks().nth(0) + } + pub fn else_branch(&self) -> Option { + let res = match self.blocks().nth(1) { + Some(block) => ElseBranch::Block(block), + None => { + let elif: ast::IfExpr = child_opt(self)?; + ElseBranch::IfExpr(elif) + } + }; + Some(res) + } + + fn blocks(&self) -> AstChildren { + children(self) + } +} diff --git a/crates/mun_syntax/src/ast/extensions.rs b/crates/mun_syntax/src/ast/extensions.rs index ad09ccf1b..d0d408a4c 100644 --- a/crates/mun_syntax/src/ast/extensions.rs +++ b/crates/mun_syntax/src/ast/extensions.rs @@ -31,14 +31,14 @@ impl ast::FunctionDef { let start = fn_kw .map(|kw| kw.start()) - .unwrap_or(self.syntax.text_range().start()); + .unwrap_or_else(|| self.syntax.text_range().start()); let end = ret_type .map(|p| p.end()) - .or(param_list.map(|name| name.end())) - .or(name.map(|name| name.end())) - .or(fn_kw.map(|kw| kw.end())) - .unwrap_or(self.syntax().text_range().end()); + .or_else(|| param_list.map(|name| name.end())) + .or_else(|| name.map(|name| name.end())) + .or_else(|| fn_kw.map(|kw| kw.end())) + .unwrap_or_else(|| self.syntax().text_range().end()); TextRange::from_to(start, end) } diff --git a/crates/mun_syntax/src/ast/generated.rs b/crates/mun_syntax/src/ast/generated.rs index 16f2f4a4f..966bbb124 100644 --- a/crates/mun_syntax/src/ast/generated.rs +++ b/crates/mun_syntax/src/ast/generated.rs @@ -103,23 +103,23 @@ impl BindPat { } } -// Block +// BlockExpr #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct Block { +pub struct BlockExpr { pub(crate) syntax: SyntaxNode, } -impl AstNode for Block { +impl AstNode for BlockExpr { fn can_cast(kind: SyntaxKind) -> bool { match kind { - BLOCK => true, + BLOCK_EXPR => true, _ => false, } } fn cast(syntax: SyntaxNode) -> Option { if Self::can_cast(syntax.kind()) { - Some(Block { syntax }) + Some(BlockExpr { syntax }) } else { None } @@ -128,7 +128,7 @@ impl AstNode for Block { &self.syntax } } -impl Block { +impl BlockExpr { pub fn statements(&self) -> impl Iterator { super::children(self) } @@ -170,6 +170,41 @@ impl CallExpr { } } +// Condition + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Condition { + pub(crate) syntax: SyntaxNode, +} + +impl AstNode for Condition { + fn can_cast(kind: SyntaxKind) -> bool { + match kind { + CONDITION => true, + _ => false, + } + } + fn cast(syntax: SyntaxNode) -> Option { + if Self::can_cast(syntax.kind()) { + Some(Condition { syntax }) + } else { + None + } + } + fn syntax(&self) -> &SyntaxNode { + &self.syntax + } +} +impl Condition { + pub fn pat(&self) -> Option { + super::child_opt(self) + } + + pub fn expr(&self) -> Option { + super::child_opt(self) + } +} + // Expr #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -180,7 +215,8 @@ pub struct Expr { impl AstNode for Expr { fn can_cast(kind: SyntaxKind) -> bool { match kind { - LITERAL | PREFIX_EXPR | PATH_EXPR | BIN_EXPR | PAREN_EXPR | CALL_EXPR => true, + LITERAL | PREFIX_EXPR | PATH_EXPR | BIN_EXPR | PAREN_EXPR | CALL_EXPR | IF_EXPR + | BLOCK_EXPR => true, _ => false, } } @@ -203,6 +239,8 @@ pub enum ExprKind { BinExpr(BinExpr), ParenExpr(ParenExpr), CallExpr(CallExpr), + IfExpr(IfExpr), + BlockExpr(BlockExpr), } impl From for Expr { fn from(n: Literal) -> Expr { @@ -234,6 +272,16 @@ impl From for Expr { Expr { syntax: n.syntax } } } +impl From for Expr { + fn from(n: IfExpr) -> Expr { + Expr { syntax: n.syntax } + } +} +impl From for Expr { + fn from(n: BlockExpr) -> Expr { + Expr { syntax: n.syntax } + } +} impl Expr { pub fn kind(&self) -> ExprKind { @@ -244,6 +292,8 @@ impl Expr { BIN_EXPR => ExprKind::BinExpr(BinExpr::cast(self.syntax.clone()).unwrap()), PAREN_EXPR => ExprKind::ParenExpr(ParenExpr::cast(self.syntax.clone()).unwrap()), CALL_EXPR => ExprKind::CallExpr(CallExpr::cast(self.syntax.clone()).unwrap()), + IF_EXPR => ExprKind::IfExpr(IfExpr::cast(self.syntax.clone()).unwrap()), + BLOCK_EXPR => ExprKind::BlockExpr(BlockExpr::cast(self.syntax.clone()).unwrap()), _ => unreachable!(), } } @@ -315,7 +365,7 @@ impl FunctionDef { super::child_opt(self) } - pub fn body(&self) -> Option { + pub fn body(&self) -> Option { super::child_opt(self) } @@ -324,6 +374,37 @@ impl FunctionDef { } } +// IfExpr + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct IfExpr { + pub(crate) syntax: SyntaxNode, +} + +impl AstNode for IfExpr { + fn can_cast(kind: SyntaxKind) -> bool { + match kind { + IF_EXPR => true, + _ => false, + } + } + fn cast(syntax: SyntaxNode) -> Option { + if Self::can_cast(syntax.kind()) { + Some(IfExpr { syntax }) + } else { + None + } + } + fn syntax(&self) -> &SyntaxNode { + &self.syntax + } +} +impl IfExpr { + pub fn condition(&self) -> Option { + super::child_opt(self) + } +} + // LetStmt #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -489,6 +570,33 @@ impl AstNode for NameRef { } impl NameRef {} +// NeverType + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct NeverType { + pub(crate) syntax: SyntaxNode, +} + +impl AstNode for NeverType { + fn can_cast(kind: SyntaxKind) -> bool { + match kind { + NEVER_TYPE => true, + _ => false, + } + } + fn cast(syntax: SyntaxNode) -> Option { + if Self::can_cast(syntax.kind()) { + Some(NeverType { syntax }) + } else { + None + } + } + fn syntax(&self) -> &SyntaxNode { + &self.syntax + } +} +impl NeverType {} + // Param #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -947,7 +1055,7 @@ pub struct TypeRef { impl AstNode for TypeRef { fn can_cast(kind: SyntaxKind) -> bool { match kind { - PATH_TYPE => true, + PATH_TYPE | NEVER_TYPE => true, _ => false, } } @@ -965,17 +1073,24 @@ impl AstNode for TypeRef { #[derive(Debug, Clone, PartialEq, Eq)] pub enum TypeRefKind { PathType(PathType), + NeverType(NeverType), } impl From for TypeRef { fn from(n: PathType) -> TypeRef { TypeRef { syntax: n.syntax } } } +impl From for TypeRef { + fn from(n: NeverType) -> TypeRef { + TypeRef { syntax: n.syntax } + } +} impl TypeRef { pub fn kind(&self) -> TypeRefKind { match self.syntax.kind() { PATH_TYPE => TypeRefKind::PathType(PathType::cast(self.syntax.clone()).unwrap()), + NEVER_TYPE => TypeRefKind::NeverType(NeverType::cast(self.syntax.clone()).unwrap()), _ => unreachable!(), } } diff --git a/crates/mun_syntax/src/grammar.ron b/crates/mun_syntax/src/grammar.ron index 3b95ddddc..c69f6086a 100644 --- a/crates/mun_syntax/src/grammar.ron +++ b/crates/mun_syntax/src/grammar.ron @@ -64,13 +64,13 @@ Grammar( "in", // "local", // We use let "nil", - //"not", // We use ! - "or", + // "not", // We use ! + // "or", "self", "super", // "repeat", // Not supported "return", - "then", + // "then", // Not supported "true", // "until", // Not supported "while", @@ -82,7 +82,8 @@ Grammar( "public", "protected", "private", - "export" + "export", + "never", ], literals: [ "INT_NUMBER", @@ -106,6 +107,7 @@ Grammar( "PARAM", "PATH_TYPE", + "NEVER_TYPE", "LET_STMT", "EXPR_STMT", @@ -116,6 +118,9 @@ Grammar( "BIN_EXPR", "PAREN_EXPR", "CALL_EXPR", + "IF_EXPR", + "BLOCK_EXPR", + "CONDITION", "BIND_PAT", "PLACEHOLDER_PAT", @@ -125,8 +130,6 @@ Grammar( "NAME", "NAME_REF", - "BLOCK", - "PATH", "PATH_SEGMENT", ], @@ -144,7 +147,7 @@ Grammar( "VisibilityOwner", "DocCommentsOwner", ], - options: [ "ParamList", ["body", "Block"], "RetType" ], + options: [ "ParamList", ["body", "BlockExpr"], "RetType" ], ), "RetType": (options: ["TypeRef"]), "ParamList": ( @@ -167,6 +170,10 @@ Grammar( "TypeAscriptionOwner", ] ), + "Condition": ( + options: [ "Pat", "Expr" ] + ), + "ExprStmt": ( options: [ ["expr", "Expr"] ] ), @@ -183,6 +190,9 @@ Grammar( traits: ["ArgListOwner"], options: [ "Expr" ], ), + "IfExpr": ( + options: [ "Condition" ] + ), "ArgList": ( collections: [ ["args", "Expr"] @@ -196,19 +206,23 @@ Grammar( "PathExpr", "BinExpr", "ParenExpr", - "CallExpr" + "CallExpr", + "IfExpr", + "BlockExpr", ] ), "Name": (), "NameRef": (), "PathType": (options: ["Path"]), + "NeverType": (), "TypeRef": ( enum: [ - "PathType" + "PathType", + "NeverType", ] ), - "Block": ( + "BlockExpr": ( options: [ "Expr" ], collections: [ ["statements", "Stmt"], @@ -236,4 +250,4 @@ Grammar( ], ), } -) \ No newline at end of file +) diff --git a/crates/mun_syntax/src/lib.rs b/crates/mun_syntax/src/lib.rs index 83f01377a..4c893547f 100644 --- a/crates/mun_syntax/src/lib.rs +++ b/crates/mun_syntax/src/lib.rs @@ -69,7 +69,7 @@ impl Parse { } impl Parse { - pub fn to_syntax(self) -> Parse { + pub fn into_syntax(self) -> Parse { Parse { green: self.green, errors: self.errors, diff --git a/crates/mun_syntax/src/parsing/grammar/expressions.rs b/crates/mun_syntax/src/parsing/grammar/expressions.rs index adb1133f7..70fab8439 100644 --- a/crates/mun_syntax/src/parsing/grammar/expressions.rs +++ b/crates/mun_syntax/src/parsing/grammar/expressions.rs @@ -27,11 +27,16 @@ pub(crate) fn block(p: &mut Parser) { p.error("expected a block"); return; } + block_expr(p); +} + +fn block_expr(p: &mut Parser) -> CompletedMarker { + assert!(p.at(T!['{'])); let m = p.start(); p.bump(T!['{']); expr_block_contents(p); p.expect(T!['}']); - m.complete(p, BLOCK); + m.complete(p, BLOCK_EXPR) } /// Parses a general statement: (let, expr, etc.) @@ -191,6 +196,8 @@ fn atom_expr(p: &mut Parser) -> Option { let marker = match p.current() { T!['('] => paren_expr(p), + T!['{'] => block_expr(p), + T![if] => if_expr(p), _ => { p.error_recover("expected expression", EXPR_RECOVERY_SET); return None; @@ -222,3 +229,26 @@ fn paren_expr(p: &mut Parser) -> CompletedMarker { p.expect(T![')']); m.complete(p, PAREN_EXPR) } + +fn if_expr(p: &mut Parser) -> CompletedMarker { + assert!(p.at(T![if])); + let m = p.start(); + p.bump(T![if]); + cond(p); + block(p); + if p.at(T![else]) { + p.bump(T![else]); + if p.at(T![if]) { + if_expr(p); + } else { + block(p); + } + } + m.complete(p, IF_EXPR) +} + +fn cond(p: &mut Parser) { + let m = p.start(); + expr(p); + m.complete(p, CONDITION); +} diff --git a/crates/mun_syntax/src/parsing/grammar/types.rs b/crates/mun_syntax/src/parsing/grammar/types.rs index 7b5aae4b1..87f02d358 100644 --- a/crates/mun_syntax/src/parsing/grammar/types.rs +++ b/crates/mun_syntax/src/parsing/grammar/types.rs @@ -9,6 +9,7 @@ pub(super) fn ascription(p: &mut Parser) { pub(super) fn type_(p: &mut Parser) { match p.current() { + T![never] => never_type(p), _ if paths::is_path_start(p) => path_type(p), _ => { p.error_recover("expected type", TYPE_RECOVERY_SET); @@ -21,3 +22,10 @@ pub(super) fn path_type(p: &mut Parser) { paths::type_path(p); m.complete(p, PATH_TYPE); } + +fn never_type(p: &mut Parser) { + assert!(p.at(T![never])); + let m = p.start(); + p.bump(T![never]); + m.complete(p, NEVER_TYPE); +} diff --git a/crates/mun_syntax/src/parsing/parser.rs b/crates/mun_syntax/src/parsing/parser.rs index c8bf6c967..401398ae8 100644 --- a/crates/mun_syntax/src/parsing/parser.rs +++ b/crates/mun_syntax/src/parsing/parser.rs @@ -84,13 +84,13 @@ impl<'t> Parser<'t> { } fn at_composite2(&self, n: usize, k1: SyntaxKind, k2: SyntaxKind) -> bool { - let t1 = self.token_source.lookahead_nth(n + 0); + let t1 = self.token_source.lookahead_nth(n); let t2 = self.token_source.lookahead_nth(n + 1); t1.kind == k1 && t1.is_jointed_to_next && t2.kind == k2 } fn at_composite3(&self, n: usize, k1: SyntaxKind, k2: SyntaxKind, k3: SyntaxKind) -> bool { - let t1 = self.token_source.lookahead_nth(n + 0); + let t1 = self.token_source.lookahead_nth(n); let t2 = self.token_source.lookahead_nth(n + 1); let t3 = self.token_source.lookahead_nth(n + 2); (t1.kind == k1 && t1.is_jointed_to_next) diff --git a/crates/mun_syntax/src/parsing/text_token_source.rs b/crates/mun_syntax/src/parsing/text_token_source.rs index a82a30200..447639f2a 100644 --- a/crates/mun_syntax/src/parsing/text_token_source.rs +++ b/crates/mun_syntax/src/parsing/text_token_source.rs @@ -49,7 +49,7 @@ impl<'t> TokenSource for TextTokenSource<'t> { fn is_keyword(&self, kw: &str) -> bool { let pos = self.curr.1; - if !(pos < self.tokens.len()) { + if pos >= self.tokens.len() { return false; } let range = TextRange::offset_len(self.start_offsets[pos], self.tokens[pos].len); diff --git a/crates/mun_syntax/src/syntax_kind/generated.rs b/crates/mun_syntax/src/syntax_kind/generated.rs index b02420e2b..2665c47dd 100644 --- a/crates/mun_syntax/src/syntax_kind/generated.rs +++ b/crates/mun_syntax/src/syntax_kind/generated.rs @@ -61,11 +61,9 @@ pub enum SyntaxKind { IF_KW, IN_KW, NIL_KW, - OR_KW, SELF_KW, SUPER_KW, RETURN_KW, - THEN_KW, TRUE_KW, WHILE_KW, LET_KW, @@ -75,6 +73,7 @@ pub enum SyntaxKind { PROTECTED_KW, PRIVATE_KW, EXPORT_KW, + NEVER_KW, INT_NUMBER, FLOAT_NUMBER, STRING, @@ -89,6 +88,7 @@ pub enum SyntaxKind { PARAM_LIST, PARAM, PATH_TYPE, + NEVER_TYPE, LET_STMT, EXPR_STMT, PATH_EXPR, @@ -97,12 +97,14 @@ pub enum SyntaxKind { BIN_EXPR, PAREN_EXPR, CALL_EXPR, + IF_EXPR, + BLOCK_EXPR, + CONDITION, BIND_PAT, PLACEHOLDER_PAT, ARG_LIST, NAME, NAME_REF, - BLOCK, PATH, PATH_SEGMENT, // Technical kind so that we can cast from u16 safely @@ -159,11 +161,9 @@ macro_rules! T { (if) => { $crate::SyntaxKind::IF_KW }; (in) => { $crate::SyntaxKind::IN_KW }; (nil) => { $crate::SyntaxKind::NIL_KW }; - (or) => { $crate::SyntaxKind::OR_KW }; (self) => { $crate::SyntaxKind::SELF_KW }; (super) => { $crate::SyntaxKind::SUPER_KW }; (return) => { $crate::SyntaxKind::RETURN_KW }; - (then) => { $crate::SyntaxKind::THEN_KW }; (true) => { $crate::SyntaxKind::TRUE_KW }; (while) => { $crate::SyntaxKind::WHILE_KW }; (let) => { $crate::SyntaxKind::LET_KW }; @@ -173,6 +173,7 @@ macro_rules! T { (protected) => { $crate::SyntaxKind::PROTECTED_KW }; (private) => { $crate::SyntaxKind::PRIVATE_KW }; (export) => { $crate::SyntaxKind::EXPORT_KW }; + (never) => { $crate::SyntaxKind::NEVER_KW }; } impl From for SyntaxKind { @@ -201,11 +202,9 @@ impl SyntaxKind { | IF_KW | IN_KW | NIL_KW - | OR_KW | SELF_KW | SUPER_KW | RETURN_KW - | THEN_KW | TRUE_KW | WHILE_KW | LET_KW @@ -215,6 +214,7 @@ impl SyntaxKind { | PROTECTED_KW | PRIVATE_KW | EXPORT_KW + | NEVER_KW => true, _ => false } @@ -321,11 +321,9 @@ impl SyntaxKind { IF_KW => &SyntaxInfo { name: "IF_KW" }, IN_KW => &SyntaxInfo { name: "IN_KW" }, NIL_KW => &SyntaxInfo { name: "NIL_KW" }, - OR_KW => &SyntaxInfo { name: "OR_KW" }, SELF_KW => &SyntaxInfo { name: "SELF_KW" }, SUPER_KW => &SyntaxInfo { name: "SUPER_KW" }, RETURN_KW => &SyntaxInfo { name: "RETURN_KW" }, - THEN_KW => &SyntaxInfo { name: "THEN_KW" }, TRUE_KW => &SyntaxInfo { name: "TRUE_KW" }, WHILE_KW => &SyntaxInfo { name: "WHILE_KW" }, LET_KW => &SyntaxInfo { name: "LET_KW" }, @@ -335,6 +333,7 @@ impl SyntaxKind { PROTECTED_KW => &SyntaxInfo { name: "PROTECTED_KW" }, PRIVATE_KW => &SyntaxInfo { name: "PRIVATE_KW" }, EXPORT_KW => &SyntaxInfo { name: "EXPORT_KW" }, + NEVER_KW => &SyntaxInfo { name: "NEVER_KW" }, INT_NUMBER => &SyntaxInfo { name: "INT_NUMBER" }, FLOAT_NUMBER => &SyntaxInfo { name: "FLOAT_NUMBER" }, STRING => &SyntaxInfo { name: "STRING" }, @@ -349,6 +348,7 @@ impl SyntaxKind { PARAM_LIST => &SyntaxInfo { name: "PARAM_LIST" }, PARAM => &SyntaxInfo { name: "PARAM" }, PATH_TYPE => &SyntaxInfo { name: "PATH_TYPE" }, + NEVER_TYPE => &SyntaxInfo { name: "NEVER_TYPE" }, LET_STMT => &SyntaxInfo { name: "LET_STMT" }, EXPR_STMT => &SyntaxInfo { name: "EXPR_STMT" }, PATH_EXPR => &SyntaxInfo { name: "PATH_EXPR" }, @@ -357,12 +357,14 @@ impl SyntaxKind { BIN_EXPR => &SyntaxInfo { name: "BIN_EXPR" }, PAREN_EXPR => &SyntaxInfo { name: "PAREN_EXPR" }, CALL_EXPR => &SyntaxInfo { name: "CALL_EXPR" }, + IF_EXPR => &SyntaxInfo { name: "IF_EXPR" }, + BLOCK_EXPR => &SyntaxInfo { name: "BLOCK_EXPR" }, + CONDITION => &SyntaxInfo { name: "CONDITION" }, BIND_PAT => &SyntaxInfo { name: "BIND_PAT" }, PLACEHOLDER_PAT => &SyntaxInfo { name: "PLACEHOLDER_PAT" }, ARG_LIST => &SyntaxInfo { name: "ARG_LIST" }, NAME => &SyntaxInfo { name: "NAME" }, NAME_REF => &SyntaxInfo { name: "NAME_REF" }, - BLOCK => &SyntaxInfo { name: "BLOCK" }, PATH => &SyntaxInfo { name: "PATH" }, PATH_SEGMENT => &SyntaxInfo { name: "PATH_SEGMENT" }, TOMBSTONE => &SyntaxInfo { name: "TOMBSTONE" }, @@ -383,11 +385,9 @@ impl SyntaxKind { "if" => IF_KW, "in" => IN_KW, "nil" => NIL_KW, - "or" => OR_KW, "self" => SELF_KW, "super" => SUPER_KW, "return" => RETURN_KW, - "then" => THEN_KW, "true" => TRUE_KW, "while" => WHILE_KW, "let" => LET_KW, @@ -397,6 +397,7 @@ impl SyntaxKind { "protected" => PROTECTED_KW, "private" => PRIVATE_KW, "export" => EXPORT_KW, + "never" => NEVER_KW, _ => return None, }; Some(kw) diff --git a/crates/mun_syntax/src/tests/lexer.rs b/crates/mun_syntax/src/tests/lexer.rs index 7897d17cd..c61e30521 100644 --- a/crates/mun_syntax/src/tests/lexer.rs +++ b/crates/mun_syntax/src/tests/lexer.rs @@ -108,9 +108,9 @@ fn strings() { fn keywords() { lex_snapshot( r#" - and break do else false for fn if in nil or - return then true while let mut class public protected - private + and break do else false for fn if in nil + return true while let mut class public protected + private never "#, ) } diff --git a/crates/mun_syntax/src/tests/parser.rs b/crates/mun_syntax/src/tests/parser.rs index 61c289660..fe06be4ea 100644 --- a/crates/mun_syntax/src/tests/parser.rs +++ b/crates/mun_syntax/src/tests/parser.rs @@ -27,7 +27,7 @@ fn function() { // Comment that belongs to the function fn a() {} fn b(value:number) {} - export fn c() {} + export fn c():never {} fn b(value:number):number {}"#, ); } @@ -139,3 +139,28 @@ fn compare_operands() { "#, ) } + +#[test] +fn if_expr() { + ok_snapshot_test( + r#" + fn bar() { + if true {}; + if true {} else {}; + if true {} else if false {} else {}; + if {true} {} else {} + } + "#, + ); +} + +#[test] +fn block_expr() { + ok_snapshot_test( + r#" + fn bar() { + {3} + } + "#, + ); +} diff --git a/crates/mun_syntax/src/tests/snapshots/lexer__keywords.snap b/crates/mun_syntax/src/tests/snapshots/lexer__keywords.snap index d10f45b48..abb9d9626 100644 --- a/crates/mun_syntax/src/tests/snapshots/lexer__keywords.snap +++ b/crates/mun_syntax/src/tests/snapshots/lexer__keywords.snap @@ -1,6 +1,6 @@ --- source: crates/mun_syntax/src/tests/lexer.rs -expression: "and break do else false for fn if in nil or\nreturn then true while let mut class public protected\nprivate" +expression: "and break do else false for fn if in nil\nreturn true while let mut class public protected\nprivate never" --- AND_KW 3 "and" WHITESPACE 1 " " @@ -21,13 +21,9 @@ WHITESPACE 1 " " IN_KW 2 "in" WHITESPACE 1 " " NIL_KW 3 "nil" -WHITESPACE 1 " " -OR_KW 2 "or" WHITESPACE 1 "\n" RETURN_KW 6 "return" WHITESPACE 1 " " -THEN_KW 4 "then" -WHITESPACE 1 " " TRUE_KW 4 "true" WHITESPACE 1 " " WHILE_KW 5 "while" @@ -43,4 +39,6 @@ WHITESPACE 1 " " PROTECTED_KW 9 "protected" WHITESPACE 1 "\n" PRIVATE_KW 7 "private" +WHITESPACE 1 " " +NEVER_KW 5 "never" diff --git a/crates/mun_syntax/src/tests/snapshots/parser__binary_expr.snap b/crates/mun_syntax/src/tests/snapshots/parser__binary_expr.snap index 08696b2bd..ccf685af7 100644 --- a/crates/mun_syntax/src/tests/snapshots/parser__binary_expr.snap +++ b/crates/mun_syntax/src/tests/snapshots/parser__binary_expr.snap @@ -12,7 +12,7 @@ SOURCE_FILE@[0; 51) L_PAREN@[6; 7) "(" R_PAREN@[7; 8) ")" WHITESPACE@[8; 9) " " - BLOCK@[9; 51) + BLOCK_EXPR@[9; 51) L_CURLY@[9; 10) "{" WHITESPACE@[10; 15) "\n " LET_STMT@[15; 28) diff --git a/crates/mun_syntax/src/tests/snapshots/parser__block.snap b/crates/mun_syntax/src/tests/snapshots/parser__block.snap index 7b3c17031..d49a9ca5c 100644 --- a/crates/mun_syntax/src/tests/snapshots/parser__block.snap +++ b/crates/mun_syntax/src/tests/snapshots/parser__block.snap @@ -12,7 +12,7 @@ SOURCE_FILE@[0; 56) L_PAREN@[6; 7) "(" R_PAREN@[7; 8) ")" WHITESPACE@[8; 9) " " - BLOCK@[9; 56) + BLOCK_EXPR@[9; 56) L_CURLY@[9; 10) "{" WHITESPACE@[10; 15) "\n " LET_STMT@[15; 21) diff --git a/crates/mun_syntax/src/tests/snapshots/parser__block_expr.snap b/crates/mun_syntax/src/tests/snapshots/parser__block_expr.snap new file mode 100644 index 000000000..f9ac48873 --- /dev/null +++ b/crates/mun_syntax/src/tests/snapshots/parser__block_expr.snap @@ -0,0 +1,25 @@ +--- +source: crates/mun_syntax/src/tests/parser.rs +expression: "fn bar() {\n {3}\n}" +--- +SOURCE_FILE@[0; 20) + FUNCTION_DEF@[0; 20) + FN_KW@[0; 2) "fn" + WHITESPACE@[2; 3) " " + NAME@[3; 6) + IDENT@[3; 6) "bar" + PARAM_LIST@[6; 8) + L_PAREN@[6; 7) "(" + R_PAREN@[7; 8) ")" + WHITESPACE@[8; 9) " " + BLOCK_EXPR@[9; 20) + L_CURLY@[9; 10) "{" + WHITESPACE@[10; 15) "\n " + BLOCK_EXPR@[15; 18) + L_CURLY@[15; 16) "{" + LITERAL@[16; 17) + INT_NUMBER@[16; 17) "3" + R_CURLY@[17; 18) "}" + WHITESPACE@[18; 19) "\n" + R_CURLY@[19; 20) "}" + diff --git a/crates/mun_syntax/src/tests/snapshots/parser__compare_operands.snap b/crates/mun_syntax/src/tests/snapshots/parser__compare_operands.snap index 01e0144f2..8875f2989 100644 --- a/crates/mun_syntax/src/tests/snapshots/parser__compare_operands.snap +++ b/crates/mun_syntax/src/tests/snapshots/parser__compare_operands.snap @@ -12,7 +12,7 @@ SOURCE_FILE@[0; 149) L_PAREN@[7; 8) "(" R_PAREN@[8; 9) ")" WHITESPACE@[9; 10) " " - BLOCK@[10; 149) + BLOCK_EXPR@[10; 149) L_CURLY@[10; 11) "{" WHITESPACE@[11; 16) "\n " LET_STMT@[16; 29) diff --git a/crates/mun_syntax/src/tests/snapshots/parser__expression_statement.snap b/crates/mun_syntax/src/tests/snapshots/parser__expression_statement.snap index 83e3ad90e..9fc56467f 100644 --- a/crates/mun_syntax/src/tests/snapshots/parser__expression_statement.snap +++ b/crates/mun_syntax/src/tests/snapshots/parser__expression_statement.snap @@ -12,7 +12,7 @@ SOURCE_FILE@[0; 110) L_PAREN@[6; 7) "(" R_PAREN@[7; 8) ")" WHITESPACE@[8; 9) " " - BLOCK@[9; 110) + BLOCK_EXPR@[9; 110) L_CURLY@[9; 10) "{" WHITESPACE@[10; 15) "\n " LET_STMT@[15; 30) diff --git a/crates/mun_syntax/src/tests/snapshots/parser__function.snap b/crates/mun_syntax/src/tests/snapshots/parser__function.snap index e7233c4cb..4d2bfba54 100644 --- a/crates/mun_syntax/src/tests/snapshots/parser__function.snap +++ b/crates/mun_syntax/src/tests/snapshots/parser__function.snap @@ -1,8 +1,8 @@ --- source: crates/mun_syntax/src/tests/parser.rs -expression: "// Source file comment\n\n// Comment that belongs to the function\nfn a() {}\nfn b(value:number) {}\nexport fn c() {}\nfn b(value:number):number {}" +expression: "// Source file comment\n\n// Comment that belongs to the function\nfn a() {}\nfn b(value:number) {}\nexport fn c():never {}\nfn b(value:number):number {}" --- -SOURCE_FILE@[0; 141) +SOURCE_FILE@[0; 147) COMMENT@[0; 22) "// Source file comment" WHITESPACE@[22; 24) "\n\n" FUNCTION_DEF@[24; 73) @@ -16,7 +16,7 @@ SOURCE_FILE@[0; 141) L_PAREN@[68; 69) "(" R_PAREN@[69; 70) ")" WHITESPACE@[70; 71) " " - BLOCK@[71; 73) + BLOCK_EXPR@[71; 73) L_CURLY@[71; 72) "{" R_CURLY@[72; 73) "}" FUNCTION_DEF@[73; 95) @@ -39,10 +39,10 @@ SOURCE_FILE@[0; 141) IDENT@[85; 91) "number" R_PAREN@[91; 92) ")" WHITESPACE@[92; 93) " " - BLOCK@[93; 95) + BLOCK_EXPR@[93; 95) L_CURLY@[93; 94) "{" R_CURLY@[94; 95) "}" - FUNCTION_DEF@[95; 112) + FUNCTION_DEF@[95; 118) WHITESPACE@[95; 96) "\n" VISIBILITY@[96; 102) EXPORT_KW@[96; 102) "export" @@ -54,38 +54,42 @@ SOURCE_FILE@[0; 141) PARAM_LIST@[107; 109) L_PAREN@[107; 108) "(" R_PAREN@[108; 109) ")" - WHITESPACE@[109; 110) " " - BLOCK@[110; 112) - L_CURLY@[110; 111) "{" - R_CURLY@[111; 112) "}" - FUNCTION_DEF@[112; 141) - WHITESPACE@[112; 113) "\n" - FN_KW@[113; 115) "fn" + RET_TYPE@[109; 115) + COLON@[109; 110) ":" + NEVER_TYPE@[110; 115) + NEVER_KW@[110; 115) "never" WHITESPACE@[115; 116) " " - NAME@[116; 117) - IDENT@[116; 117) "b" - PARAM_LIST@[117; 131) - L_PAREN@[117; 118) "(" - PARAM@[118; 130) - BIND_PAT@[118; 123) - NAME@[118; 123) - IDENT@[118; 123) "value" - COLON@[123; 124) ":" - PATH_TYPE@[124; 130) - PATH@[124; 130) - PATH_SEGMENT@[124; 130) - NAME_REF@[124; 130) - IDENT@[124; 130) "number" - R_PAREN@[130; 131) ")" - RET_TYPE@[131; 138) - COLON@[131; 132) ":" - PATH_TYPE@[132; 138) - PATH@[132; 138) - PATH_SEGMENT@[132; 138) - NAME_REF@[132; 138) - IDENT@[132; 138) "number" - WHITESPACE@[138; 139) " " - BLOCK@[139; 141) - L_CURLY@[139; 140) "{" - R_CURLY@[140; 141) "}" + BLOCK_EXPR@[116; 118) + L_CURLY@[116; 117) "{" + R_CURLY@[117; 118) "}" + FUNCTION_DEF@[118; 147) + WHITESPACE@[118; 119) "\n" + FN_KW@[119; 121) "fn" + WHITESPACE@[121; 122) " " + NAME@[122; 123) + IDENT@[122; 123) "b" + PARAM_LIST@[123; 137) + L_PAREN@[123; 124) "(" + PARAM@[124; 136) + BIND_PAT@[124; 129) + NAME@[124; 129) + IDENT@[124; 129) "value" + COLON@[129; 130) ":" + PATH_TYPE@[130; 136) + PATH@[130; 136) + PATH_SEGMENT@[130; 136) + NAME_REF@[130; 136) + IDENT@[130; 136) "number" + R_PAREN@[136; 137) ")" + RET_TYPE@[137; 144) + COLON@[137; 138) ":" + PATH_TYPE@[138; 144) + PATH@[138; 144) + PATH_SEGMENT@[138; 144) + NAME_REF@[138; 144) + IDENT@[138; 144) "number" + WHITESPACE@[144; 145) " " + BLOCK_EXPR@[145; 147) + L_CURLY@[145; 146) "{" + R_CURLY@[146; 147) "}" diff --git a/crates/mun_syntax/src/tests/snapshots/parser__function_calls.snap b/crates/mun_syntax/src/tests/snapshots/parser__function_calls.snap index 0d05ada09..8e9d7e94e 100644 --- a/crates/mun_syntax/src/tests/snapshots/parser__function_calls.snap +++ b/crates/mun_syntax/src/tests/snapshots/parser__function_calls.snap @@ -22,7 +22,7 @@ SOURCE_FILE@[0; 52) IDENT@[9; 15) "number" R_PAREN@[15; 16) ")" WHITESPACE@[16; 17) " " - BLOCK@[17; 20) + BLOCK_EXPR@[17; 20) L_CURLY@[17; 18) "{" WHITESPACE@[18; 19) " " R_CURLY@[19; 20) "}" @@ -46,7 +46,7 @@ SOURCE_FILE@[0; 52) IDENT@[30; 36) "number" R_PAREN@[36; 37) ")" WHITESPACE@[37; 38) " " - BLOCK@[38; 52) + BLOCK_EXPR@[38; 52) L_CURLY@[38; 39) "{" WHITESPACE@[39; 42) "\n " CALL_EXPR@[42; 50) diff --git a/crates/mun_syntax/src/tests/snapshots/parser__if_expr.snap b/crates/mun_syntax/src/tests/snapshots/parser__if_expr.snap new file mode 100644 index 000000000..accca1415 --- /dev/null +++ b/crates/mun_syntax/src/tests/snapshots/parser__if_expr.snap @@ -0,0 +1,103 @@ +--- +source: crates/mun_syntax/src/tests/parser.rs +expression: "fn bar() {\n if true {};\n if true {} else {};\n if true {} else if false {} else {};\n if {true} {} else {}\n}" +--- +SOURCE_FILE@[0; 118) + FUNCTION_DEF@[0; 118) + FN_KW@[0; 2) "fn" + WHITESPACE@[2; 3) " " + NAME@[3; 6) + IDENT@[3; 6) "bar" + PARAM_LIST@[6; 8) + L_PAREN@[6; 7) "(" + R_PAREN@[7; 8) ")" + WHITESPACE@[8; 9) " " + BLOCK_EXPR@[9; 118) + L_CURLY@[9; 10) "{" + WHITESPACE@[10; 15) "\n " + EXPR_STMT@[15; 26) + IF_EXPR@[15; 25) + IF_KW@[15; 17) "if" + WHITESPACE@[17; 18) " " + CONDITION@[18; 22) + LITERAL@[18; 22) + TRUE_KW@[18; 22) "true" + WHITESPACE@[22; 23) " " + BLOCK_EXPR@[23; 25) + L_CURLY@[23; 24) "{" + R_CURLY@[24; 25) "}" + SEMI@[25; 26) ";" + WHITESPACE@[26; 31) "\n " + EXPR_STMT@[31; 50) + IF_EXPR@[31; 49) + IF_KW@[31; 33) "if" + WHITESPACE@[33; 34) " " + CONDITION@[34; 38) + LITERAL@[34; 38) + TRUE_KW@[34; 38) "true" + WHITESPACE@[38; 39) " " + BLOCK_EXPR@[39; 41) + L_CURLY@[39; 40) "{" + R_CURLY@[40; 41) "}" + WHITESPACE@[41; 42) " " + ELSE_KW@[42; 46) "else" + WHITESPACE@[46; 47) " " + BLOCK_EXPR@[47; 49) + L_CURLY@[47; 48) "{" + R_CURLY@[48; 49) "}" + SEMI@[49; 50) ";" + WHITESPACE@[50; 55) "\n " + EXPR_STMT@[55; 91) + IF_EXPR@[55; 90) + IF_KW@[55; 57) "if" + WHITESPACE@[57; 58) " " + CONDITION@[58; 62) + LITERAL@[58; 62) + TRUE_KW@[58; 62) "true" + WHITESPACE@[62; 63) " " + BLOCK_EXPR@[63; 65) + L_CURLY@[63; 64) "{" + R_CURLY@[64; 65) "}" + WHITESPACE@[65; 66) " " + ELSE_KW@[66; 70) "else" + WHITESPACE@[70; 71) " " + IF_EXPR@[71; 90) + IF_KW@[71; 73) "if" + WHITESPACE@[73; 74) " " + CONDITION@[74; 79) + LITERAL@[74; 79) + FALSE_KW@[74; 79) "false" + WHITESPACE@[79; 80) " " + BLOCK_EXPR@[80; 82) + L_CURLY@[80; 81) "{" + R_CURLY@[81; 82) "}" + WHITESPACE@[82; 83) " " + ELSE_KW@[83; 87) "else" + WHITESPACE@[87; 88) " " + BLOCK_EXPR@[88; 90) + L_CURLY@[88; 89) "{" + R_CURLY@[89; 90) "}" + SEMI@[90; 91) ";" + WHITESPACE@[91; 96) "\n " + IF_EXPR@[96; 116) + IF_KW@[96; 98) "if" + WHITESPACE@[98; 99) " " + CONDITION@[99; 105) + BLOCK_EXPR@[99; 105) + L_CURLY@[99; 100) "{" + LITERAL@[100; 104) + TRUE_KW@[100; 104) "true" + R_CURLY@[104; 105) "}" + WHITESPACE@[105; 106) " " + BLOCK_EXPR@[106; 108) + L_CURLY@[106; 107) "{" + R_CURLY@[107; 108) "}" + WHITESPACE@[108; 109) " " + ELSE_KW@[109; 113) "else" + WHITESPACE@[113; 114) " " + BLOCK_EXPR@[114; 116) + L_CURLY@[114; 115) "{" + R_CURLY@[115; 116) "}" + WHITESPACE@[116; 117) "\n" + R_CURLY@[117; 118) "}" + diff --git a/crates/mun_syntax/src/tests/snapshots/parser__literals.snap b/crates/mun_syntax/src/tests/snapshots/parser__literals.snap index e4f893ad0..2f8d223a7 100644 --- a/crates/mun_syntax/src/tests/snapshots/parser__literals.snap +++ b/crates/mun_syntax/src/tests/snapshots/parser__literals.snap @@ -12,7 +12,7 @@ SOURCE_FILE@[0; 110) L_PAREN@[6; 7) "(" R_PAREN@[7; 8) ")" WHITESPACE@[8; 9) " " - BLOCK@[9; 110) + BLOCK_EXPR@[9; 110) L_CURLY@[9; 10) "{" WHITESPACE@[10; 15) "\n " LET_STMT@[15; 28) diff --git a/crates/mun_syntax/src/tests/snapshots/parser__patterns.snap b/crates/mun_syntax/src/tests/snapshots/parser__patterns.snap index 992724a17..d6763b552 100644 --- a/crates/mun_syntax/src/tests/snapshots/parser__patterns.snap +++ b/crates/mun_syntax/src/tests/snapshots/parser__patterns.snap @@ -21,7 +21,7 @@ SOURCE_FILE@[0; 49) IDENT@[10; 16) "number" R_PAREN@[16; 17) ")" WHITESPACE@[17; 18) " " - BLOCK@[18; 49) + BLOCK_EXPR@[18; 49) L_CURLY@[18; 19) "{" WHITESPACE@[19; 23) "\n " LET_STMT@[23; 33) diff --git a/crates/mun_syntax/src/tests/snapshots/parser__unary_expr.snap b/crates/mun_syntax/src/tests/snapshots/parser__unary_expr.snap index ee9971e45..3976ed5f6 100644 --- a/crates/mun_syntax/src/tests/snapshots/parser__unary_expr.snap +++ b/crates/mun_syntax/src/tests/snapshots/parser__unary_expr.snap @@ -12,7 +12,7 @@ SOURCE_FILE@[0; 49) L_PAREN@[6; 7) "(" R_PAREN@[7; 8) ")" WHITESPACE@[8; 9) " " - BLOCK@[9; 49) + BLOCK_EXPR@[9; 49) L_CURLY@[9; 10) "{" WHITESPACE@[10; 15) "\n " LET_STMT@[15; 27)