From 9dc084e443b1723a234e9e1f104df2071f9ee7b3 Mon Sep 17 00:00:00 2001 From: Wodann Date: Sun, 1 Mar 2020 14:52:57 +0100 Subject: [PATCH] feat(code_gen): create marshallable wrapper for unmarshallable functions A function cannot be marshalled when one of its parameters or its return type are a value struct --- crates/mun_codegen/src/code_gen/symbols.rs | 9 +- crates/mun_codegen/src/db.rs | 6 +- crates/mun_codegen/src/ir/adt.rs | 4 +- crates/mun_codegen/src/ir/body.rs | 172 ++++++++++++++------ crates/mun_codegen/src/ir/dispatch_table.rs | 12 +- crates/mun_codegen/src/ir/function.rs | 31 +++- crates/mun_codegen/src/ir/module.rs | 48 +++++- crates/mun_codegen/src/ir/ty.rs | 26 ++- crates/mun_codegen/src/lib.rs | 7 + crates/mun_hir/src/ty.rs | 13 +- 10 files changed, 248 insertions(+), 80 deletions(-) diff --git a/crates/mun_codegen/src/code_gen/symbols.rs b/crates/mun_codegen/src/code_gen/symbols.rs index ffa2f382f..d8954c0ba 100644 --- a/crates/mun_codegen/src/code_gen/symbols.rs +++ b/crates/mun_codegen/src/code_gen/symbols.rs @@ -5,7 +5,7 @@ use crate::ir::{ }; use crate::type_info::{TypeGroup, TypeInfo}; use crate::values::{BasicValue, GlobalValue}; -use crate::IrDatabase; +use crate::{CodeGenParams, IrDatabase}; use hir::Ty; use inkwell::{ attributes::Attribute, @@ -262,7 +262,6 @@ fn gen_function_info_array<'a, D: IrDatabase>( functions: impl Iterator, ) -> GlobalArrayValue { let function_infos: Vec = functions - .filter(|(f, _)| f.visibility(db) == hir::Visibility::Public) .map(|(f, value)| { // Get the function from the cloned module and modify the linkage of the function. let value = module @@ -321,9 +320,9 @@ fn gen_struct_info( (0..fields.len()).map(|idx| target_data.offset_of_element(&t, idx as u32).unwrap()); let (field_offsets, _) = gen_u16_array(module, field_offsets); - let field_sizes = fields - .iter() - .map(|field| target_data.get_store_size(&db.type_ir(field.ty(db)))); + let field_sizes = fields.iter().map(|field| { + target_data.get_store_size(&db.type_ir(field.ty(db), CodeGenParams { is_extern: false })) + }); let (field_sizes, _) = gen_u16_array(module, field_sizes); types.struct_info_type.const_named_struct(&[ diff --git a/crates/mun_codegen/src/db.rs b/crates/mun_codegen/src/db.rs index 5b911aebd..332c4ee73 100644 --- a/crates/mun_codegen/src/db.rs +++ b/crates/mun_codegen/src/db.rs @@ -1,6 +1,6 @@ #![allow(clippy::type_repetition_in_bounds)] -use crate::{ir::module::ModuleIR, type_info::TypeInfo, Context}; +use crate::{ir::module::ModuleIR, type_info::TypeInfo, CodeGenParams, Context}; use inkwell::types::StructType; use inkwell::{types::AnyTypeEnum, OptimizationLevel}; use mun_target::spec::Target; @@ -22,9 +22,9 @@ pub trait IrDatabase: hir::HirDatabase { #[salsa::input] fn target(&self) -> Target; - /// Given a type, return the corresponding IR type. + /// Given a type and code generation parameters, return the corresponding IR type. #[salsa::invoke(crate::ir::ty::ir_query)] - fn type_ir(&self, ty: hir::Ty) -> AnyTypeEnum; + fn type_ir(&self, ty: hir::Ty, params: CodeGenParams) -> AnyTypeEnum; /// Given a struct, return the corresponding IR type. #[salsa::invoke(crate::ir::ty::struct_ty_query)] diff --git a/crates/mun_codegen/src/ir/adt.rs b/crates/mun_codegen/src/ir/adt.rs index b8c00b271..66105d831 100644 --- a/crates/mun_codegen/src/ir/adt.rs +++ b/crates/mun_codegen/src/ir/adt.rs @@ -1,6 +1,6 @@ //use crate::ir::module::Types; use crate::ir::try_convert_any_to_basic; -use crate::IrDatabase; +use crate::{CodeGenParams, IrDatabase}; use inkwell::types::{BasicTypeEnum, StructType}; pub(super) fn gen_struct_decl(db: &impl IrDatabase, s: hir::Struct) -> StructType { @@ -11,7 +11,7 @@ pub(super) fn gen_struct_decl(db: &impl IrDatabase, s: hir::Struct) -> StructTyp .iter() .map(|field| { let field_type = field.ty(db); - try_convert_any_to_basic(db.type_ir(field_type)) + try_convert_any_to_basic(db.type_ir(field_type, CodeGenParams { is_extern: false })) .expect("could not convert field type") }) .collect(); diff --git a/crates/mun_codegen/src/ir/body.rs b/crates/mun_codegen/src/ir/body.rs index 3bef260ce..741fa0521 100644 --- a/crates/mun_codegen/src/ir/body.rs +++ b/crates/mun_codegen/src/ir/body.rs @@ -1,5 +1,7 @@ use crate::intrinsics; -use crate::{ir::dispatch_table::DispatchTable, ir::try_convert_any_to_basic, IrDatabase}; +use crate::{ + ir::dispatch_table::DispatchTable, ir::try_convert_any_to_basic, CodeGenParams, IrDatabase, +}; use hir::{ ArenaId, ArithOp, BinaryOp, Body, CmpOp, Expr, ExprId, HirDisplay, InferenceResult, Literal, Name, Ordering, Pat, PatId, Path, Resolution, Resolver, Statement, TypeCtor, @@ -7,7 +9,7 @@ use hir::{ use inkwell::{ builder::Builder, module::Module, - values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue}, + values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue, StructValue}, AddressSpace, FloatPredicate, IntPredicate, }; use std::{collections::HashMap, mem, sync::Arc}; @@ -37,6 +39,7 @@ pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> { dispatch_table: &'b DispatchTable, active_loop: Option, hir_function: hir::Function, + params: CodeGenParams, } impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { @@ -47,6 +50,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { ir_function: FunctionValue, function_map: &'a HashMap, dispatch_table: &'b DispatchTable, + params: CodeGenParams, ) -> Self { // Get the type information from the `hir::Function` let body = hir_function.body(db); @@ -72,6 +76,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { dispatch_table, active_loop: None, hir_function, + params, } } @@ -127,6 +132,50 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } } + pub fn gen_fn_wrapper(&mut self) { + let fn_sig = self.hir_function.ty(self.db).callable_sig(self.db).unwrap(); + let args: Vec = fn_sig + .params() + .iter() + .enumerate() + .map(|(idx, ty)| { + let param = self.fn_value.get_nth_param(idx as u32).unwrap(); + self.opt_deref_value(ty.clone(), param) + }) + .collect(); + + let ret_value = self + .gen_call(self.hir_function, &args) + .try_as_basic_value() + .left(); + + let call_return_type = &self.infer[self.body.body_expr()]; + if !call_return_type.is_never() { + let fn_ret_type = self + .hir_function + .ty(self.db) + .callable_sig(self.db) + .unwrap() + .ret() + .clone(); + + if fn_ret_type.is_empty() { + self.builder.build_return(None); + } else if let Some(value) = ret_value { + let ret_value = if let Some(hir_struct) = fn_ret_type.as_struct() { + if hir_struct.data(self.db).memory_kind == hir::StructMemoryKind::Value { + self.gen_struct_alloc_on_heap(hir_struct, value.into_struct_value()) + } else { + value + } + } else { + value + }; + self.builder.build_return(Some(&ret_value)); + } + } + } + /// Generates IR for the specified expression. Dependending on the type of expression an IR /// value is returned. fn gen_expr(&mut self, expr: ExprId) -> Option { @@ -152,6 +201,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { // Get the callable definition from the map match self.infer[*callee].as_callable_def() { Some(hir::CallableDef::Function(def)) => { + // Get all the arguments + let args: Vec = args + .iter() + .map(|expr| self.gen_expr(*expr).expect("expected a value")) + .collect(); + self.gen_call(def, &args).try_as_basic_value().left() } Some(hir::CallableDef::Struct(_)) => Some(self.gen_named_tuple_lit(expr, args)), @@ -235,37 +290,45 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { hir::StructMemoryKind::Value => struct_lit.into(), hir::StructMemoryKind::GC => { // TODO: Root memory in GC - let struct_ir_ty = self.db.struct_ty(hir_struct); - let malloc_fn_ptr = self - .dispatch_table - .gen_intrinsic_lookup(&self.builder, &intrinsics::malloc); - let mem_ptr = self - .builder - .build_call( - malloc_fn_ptr, - &[ - struct_ir_ty.size_of().unwrap().into(), - struct_ir_ty.get_alignment().into(), - ], - "malloc", - ) - .try_as_basic_value() - .left() - .unwrap(); - let struct_ptr = self - .builder - .build_bitcast( - mem_ptr, - struct_ir_ty.ptr_type(AddressSpace::Generic), - &hir_struct.name(self.db).to_string(), - ) - .into_pointer_value(); - self.builder.build_store(struct_ptr, struct_lit); - struct_ptr.into() + self.gen_struct_alloc_on_heap(hir_struct, struct_lit) } } } + fn gen_struct_alloc_on_heap( + &mut self, + hir_struct: hir::Struct, + struct_lit: StructValue, + ) -> BasicValueEnum { + let struct_ir_ty = self.db.struct_ty(hir_struct); + let malloc_fn_ptr = self + .dispatch_table + .gen_intrinsic_lookup(&self.builder, &intrinsics::malloc); + let mem_ptr = self + .builder + .build_call( + malloc_fn_ptr, + &[ + struct_ir_ty.size_of().unwrap().into(), + struct_ir_ty.get_alignment().into(), + ], + "malloc", + ) + .try_as_basic_value() + .left() + .unwrap(); + let struct_ptr = self + .builder + .build_bitcast( + mem_ptr, + struct_ir_ty.ptr_type(AddressSpace::Generic), + &hir_struct.name(self.db).to_string(), + ) + .into_pointer_value(); + self.builder.build_store(struct_ptr, struct_lit); + struct_ptr.into() + } + /// Generates IR for a record literal, e.g. `Foo { a: 1.23, b: 4 }` fn gen_record_lit( &mut self, @@ -349,8 +412,11 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { Pat::Bind { name } => { let builder = self.new_alloca_builder(); let pat_ty = self.infer[pat].clone(); - let ty = try_convert_any_to_basic(self.db.type_ir(pat_ty.clone())) - .expect("expected basic type"); + let ty = try_convert_any_to_basic( + self.db + .type_ir(pat_ty.clone(), CodeGenParams { is_extern: false }), + ) + .expect("expected basic type"); let ptr = builder.build_alloca(ty, &name.to_string()); self.pat_to_local.insert(pat, ptr); self.pat_to_name.insert(pat, name.to_string()); @@ -394,8 +460,8 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } /// Given an expression and the type of the expression, optionally dereference the value. - fn opt_deref_value(&mut self, expr: ExprId, value: BasicValueEnum) -> BasicValueEnum { - match &self.infer[expr] { + fn opt_deref_value(&mut self, ty: hir::Ty, value: BasicValueEnum) -> BasicValueEnum { + match ty { hir::Ty::Apply(hir::ApplicationTy { ctor: hir::TypeCtor::Struct(s), .. @@ -403,7 +469,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { hir::StructMemoryKind::GC => { self.builder.build_load(value.into_pointer_value(), "deref") } - hir::StructMemoryKind::Value => value, + hir::StructMemoryKind::Value => { + if self.params.is_extern { + self.builder.build_load(value.into_pointer_value(), "deref") + } else { + value + } + } }, _ => value, } @@ -460,12 +532,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { ) -> Option { let lhs = self .gen_expr(lhs_expr) - .map(|value| self.opt_deref_value(lhs_expr, value)) + .map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value)) .expect("no lhs value") .into_float_value(); let rhs = self .gen_expr(rhs_expr) - .map(|value| self.opt_deref_value(rhs_expr, value)) + .map(|value| self.opt_deref_value(self.infer[rhs_expr].clone(), value)) .expect("no rhs value") .into_float_value(); match op { @@ -519,12 +591,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { ) -> Option { let lhs = self .gen_expr(lhs_expr) - .map(|value| self.opt_deref_value(lhs_expr, value)) + .map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value)) .expect("no lhs value") .into_int_value(); let rhs = self .gen_expr(rhs_expr) - .map(|value| self.opt_deref_value(lhs_expr, value)) + .map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value)) .expect("no rhs value") .into_int_value(); match op { @@ -609,19 +681,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } } - // TODO: Implement me! fn should_use_dispatch_table(&self) -> bool { - true + // FIXME: When we use the dispatch table, generated wrappers have infinite recursion + !self.params.is_extern } /// Generates IR for a function call. - fn gen_call(&mut self, function: hir::Function, args: &[ExprId]) -> CallSiteValue { - // Get all the arguments - let args: Vec = args - .iter() - .map(|expr| self.gen_expr(*expr).expect("expected a value")) - .collect(); - + fn gen_call(&mut self, function: hir::Function, args: &[BasicValueEnum]) -> CallSiteValue { if self.should_use_dispatch_table() { let ptr_value = self.dispatch_table @@ -649,7 +715,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { // Generate IR for the condition let condition_ir = self .gen_expr(condition) - .map(|value| self.opt_deref_value(condition, value))? + .map(|value| self.opt_deref_value(self.infer[condition].clone(), value))? .into_int_value(); // Generate the code blocks to branch to @@ -787,7 +853,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { self.builder.position_at_end(&cond_block); let condition_ir = self .gen_expr(condition_expr) - .map(|value| self.opt_deref_value(condition_expr, value)); + .map(|value| self.opt_deref_value(self.infer[condition_expr].clone(), value)); if let Some(condition_ir) = condition_ir { self.builder.build_conditional_branch( condition_ir.into_int_value(), @@ -844,11 +910,11 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } fn gen_field(&mut self, _expr: ExprId, receiver_expr: ExprId, name: &Name) -> PointerValue { - let receiver_ty = &self.infer[receiver_expr] + let hir_struct = self.infer[receiver_expr] .as_struct() .expect("expected a struct"); - let field_idx = receiver_ty + let field_idx = hir_struct .field(self.db, name) .expect("expected a struct field") .id() @@ -857,13 +923,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { let receiver_ptr = self.gen_place_expr(receiver_expr); let receiver_ptr = self - .opt_deref_value(receiver_expr, receiver_ptr.into()) + .opt_deref_value(self.infer[receiver_expr].clone(), receiver_ptr.into()) .into_pointer_value(); unsafe { self.builder.build_struct_gep( receiver_ptr, field_idx, - &format!("{}.{}", receiver_ty.name(self.db), name), + &format!("{}.{}", hir_struct.name(self.db), name), ) } } diff --git a/crates/mun_codegen/src/ir/dispatch_table.rs b/crates/mun_codegen/src/ir/dispatch_table.rs index 33f924f34..baaef9f7f 100644 --- a/crates/mun_codegen/src/ir/dispatch_table.rs +++ b/crates/mun_codegen/src/ir/dispatch_table.rs @@ -1,6 +1,6 @@ use crate::intrinsics; use crate::values::FunctionValue; -use crate::IrDatabase; +use crate::{CodeGenParams, IrDatabase}; use inkwell::module::Module; use inkwell::types::{BasicTypeEnum, FunctionType}; use inkwell::values::{BasicValueEnum, PointerValue}; @@ -225,7 +225,10 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> { let name = function.name(self.db).to_string(); let hir_type = function.ty(self.db); let sig = hir_type.callable_sig(self.db).unwrap(); - let ir_type = self.db.type_ir(hir_type).into_function_type(); + let ir_type = self + .db + .type_ir(hir_type, CodeGenParams { is_extern: false }) + .into_function_type(); let arg_types = sig .params() .iter() @@ -282,6 +285,11 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> { self.collect_expr(body.body_expr(), body, infer); } + /// Collect the call expression from the body of a wrapper for the specified function. + pub fn collect_wrapper_body(&mut self, _function: hir::Function) { + self.collect_intrinsic(&intrinsics::malloc) + } + /// This creates the final DispatchTable with all *called* functions from within the module /// # Parameters /// * **functions**: Mapping of *defined* Mun functions to their respective IR values. diff --git a/crates/mun_codegen/src/ir/function.rs b/crates/mun_codegen/src/ir/function.rs index 357d0b92d..c3f7b0bc9 100644 --- a/crates/mun_codegen/src/ir/function.rs +++ b/crates/mun_codegen/src/ir/function.rs @@ -1,7 +1,7 @@ use crate::ir::body::BodyIrGenerator; use crate::ir::dispatch_table::DispatchTable; use crate::values::FunctionValue; -use crate::{IrDatabase, Module, OptimizationLevel}; +use crate::{CodeGenParams, IrDatabase, Module, OptimizationLevel}; use inkwell::passes::{PassManager, PassManagerBuilder}; use inkwell::types::AnyTypeEnum; @@ -30,9 +30,10 @@ pub(crate) fn gen_signature( db: &impl IrDatabase, f: hir::Function, module: &Module, + params: CodeGenParams, ) -> FunctionValue { let name = f.name(db).to_string(); - if let AnyTypeEnum::FunctionType(ty) = db.type_ir(f.ty(db)) { + if let AnyTypeEnum::FunctionType(ty) = db.type_ir(f.ty(db), params) { module.add_function(&name, ty, None) } else { panic!("not a function type") @@ -55,9 +56,35 @@ pub(crate) fn gen_body<'a, 'b, D: IrDatabase>( llvm_function, llvm_functions, dispatch_table, + CodeGenParams { is_extern: false }, ); code_gen.gen_fn_body(); llvm_function } + +/// Generates the body of a wrapper around `hir::Function` for its associated +/// `FunctionValue` +pub(crate) fn gen_wrapper_body<'a, 'b, D: IrDatabase>( + db: &'a D, + hir_function: hir::Function, + llvm_function: FunctionValue, + module: &'a Module, + llvm_functions: &'a HashMap, + dispatch_table: &'b DispatchTable, +) -> FunctionValue { + let mut code_gen = BodyIrGenerator::new( + db, + module, + hir_function, + llvm_function, + llvm_functions, + dispatch_table, + CodeGenParams { is_extern: true }, + ); + + code_gen.gen_fn_wrapper(); + + llvm_function +} diff --git a/crates/mun_codegen/src/ir/module.rs b/crates/mun_codegen/src/ir/module.rs index d3d54578a..0a303e992 100644 --- a/crates/mun_codegen/src/ir/module.rs +++ b/crates/mun_codegen/src/ir/module.rs @@ -2,7 +2,7 @@ use super::adt; use crate::ir::dispatch_table::{DispatchTable, DispatchTableBuilder}; use crate::ir::function; use crate::type_info::TypeInfo; -use crate::IrDatabase; +use crate::{CodeGenParams, IrDatabase}; use hir::{FileId, ModuleDef}; use inkwell::{module::Module, values::FunctionValue}; use std::collections::{HashMap, HashSet}; @@ -47,6 +47,7 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc { // Generate all the function signatures let mut functions = HashMap::new(); + let mut wrappers = HashMap::new(); let mut dispatch_table_builder = DispatchTableBuilder::new(db, &llvm_module); for def in db.module_data(file_id).definitions() { // TODO: Remove once we have more ModuleDef variants @@ -65,13 +66,31 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc { } // Construct the function signature - let fun = function::gen_signature(db, *f, &llvm_module); + let fun = function::gen_signature( + db, + *f, + &llvm_module, + CodeGenParams { is_extern: false }, + ); functions.insert(*f, fun); // Add calls to the dispatch table let body = f.body(db); let infer = f.infer(db); dispatch_table_builder.collect_body(&body, &infer); + + if f.data(db).visibility() != hir::Visibility::Private && !fn_sig.marshallable(db) { + let wrapper_fun = function::gen_signature( + db, + *f, + &llvm_module, + CodeGenParams { is_extern: true }, + ); + wrappers.insert(*f, wrapper_fun); + + // Add calls from the function's wrapper to the dispatch table + dispatch_table_builder.collect_wrapper_body(*f); + } } _ => {} } @@ -94,6 +113,18 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc { fn_pass_manager.run_on(llvm_function); } + for (hir_function, llvm_function) in wrappers.iter() { + function::gen_wrapper_body( + db, + *hir_function, + *llvm_function, + &llvm_module, + &functions, + &dispatch_table, + ); + fn_pass_manager.run_on(llvm_function); + } + // Dispatch entries can include previously unchecked intrinsics for entry in dispatch_table.entries().iter() { // Collect argument types @@ -106,10 +137,21 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc { } } + // Filter private methods + let mut api: HashMap = functions + .into_iter() + .filter(|(f, _)| f.visibility(db) != hir::Visibility::Private) + .collect(); + + // Replace non-marshallable functions with their marshallable wrappers + for (hir_function, llvm_function) in wrappers { + api.insert(hir_function, llvm_function); + } + Arc::new(ModuleIR { file_id, llvm_module, - functions, + functions: api, types, dispatch_table, }) diff --git a/crates/mun_codegen/src/ir/ty.rs b/crates/mun_codegen/src/ir/ty.rs index 280fe0595..165b907c0 100644 --- a/crates/mun_codegen/src/ir/ty.rs +++ b/crates/mun_codegen/src/ir/ty.rs @@ -1,14 +1,14 @@ use super::try_convert_any_to_basic; use crate::{ type_info::{TypeGroup, TypeInfo}, - IrDatabase, + CodeGenParams, IrDatabase, }; use hir::{ApplicationTy, CallableDef, Ty, TypeCtor}; use inkwell::types::{AnyTypeEnum, BasicType, BasicTypeEnum, StructType}; use inkwell::AddressSpace; /// Given a mun type, construct an LLVM IR type -pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty) -> AnyTypeEnum { +pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty, params: CodeGenParams) -> AnyTypeEnum { let context = db.context(); match ty { Ty::Empty => AnyTypeEnum::StructType(context.struct_type(&[], false)), @@ -18,17 +18,19 @@ pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty) -> AnyTypeEnum { TypeCtor::Bool => AnyTypeEnum::IntType(context.bool_type()), TypeCtor::FnDef(def @ CallableDef::Function(_)) => { let ty = db.callable_sig(def); - let params: Vec = ty + let param_tys: Vec = ty .params() .iter() - .map(|p| try_convert_any_to_basic(db.type_ir(p.clone())).unwrap()) + .map(|p| { + try_convert_any_to_basic(db.type_ir(p.clone(), params.clone())).unwrap() + }) .collect(); let fn_type = match ty.ret() { - Ty::Empty => context.void_type().fn_type(¶ms, false), - ty => try_convert_any_to_basic(db.type_ir(ty.clone())) + Ty::Empty => context.void_type().fn_type(¶m_tys, false), + ty => try_convert_any_to_basic(db.type_ir(ty.clone(), params)) .expect("could not convert return value") - .fn_type(¶ms, false), + .fn_type(¶m_tys, false), }; AnyTypeEnum::FunctionType(fn_type) @@ -37,7 +39,13 @@ pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty) -> AnyTypeEnum { let struct_ty = db.struct_ty(s); match s.data(db).memory_kind { hir::StructMemoryKind::GC => struct_ty.ptr_type(AddressSpace::Generic).into(), - hir::StructMemoryKind::Value => struct_ty.into(), + hir::StructMemoryKind::Value => { + if params.is_extern { + struct_ty.ptr_type(AddressSpace::Generic).into() + } else { + struct_ty.into() + } + } } } _ => unreachable!(), @@ -51,7 +59,7 @@ pub fn struct_ty_query(db: &impl IrDatabase, s: hir::Struct) -> StructType { let name = s.name(db).to_string(); for field in s.fields(db).iter() { // Ensure that salsa's cached value incorporates the struct fields - let _field_type_ir = db.type_ir(field.ty(db)); + let _field_type_ir = db.type_ir(field.ty(db), CodeGenParams { is_extern: false }); } db.context().opaque_struct_type(&name) diff --git a/crates/mun_codegen/src/lib.rs b/crates/mun_codegen/src/lib.rs index f674be161..8544fb805 100644 --- a/crates/mun_codegen/src/lib.rs +++ b/crates/mun_codegen/src/lib.rs @@ -19,3 +19,10 @@ pub use crate::{ code_gen::write_module_shared_object, db::{IrDatabase, IrDatabaseStorage}, }; + +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] +pub struct CodeGenParams { + /// Whether generated code should support extern function calls. + /// This allows function parameters with `struct(value)` types to be marshalled. + is_extern: bool, +} diff --git a/crates/mun_hir/src/ty.rs b/crates/mun_hir/src/ty.rs index 6f3f84cf3..0c253455a 100644 --- a/crates/mun_hir/src/ty.rs +++ b/crates/mun_hir/src/ty.rs @@ -5,7 +5,7 @@ mod op; use crate::display::{HirDisplay, HirFormatter}; use crate::ty::infer::TypeVarId; use crate::ty::lower::fn_sig_for_struct_constructor; -use crate::{HirDatabase, Struct}; +use crate::{HirDatabase, Struct, StructMemoryKind}; pub(crate) use infer::infer_query; pub use infer::InferenceResult; pub(crate) use lower::{callable_item_sig, fn_sig_for_fn, type_for_def, CallableDef, TypableDef}; @@ -172,6 +172,17 @@ impl FnSig { pub fn ret(&self) -> &Ty { &self.params_and_return[self.params_and_return.len() - 1] } + + pub fn marshallable(&self, db: &impl HirDatabase) -> bool { + for ty in self.params_and_return.iter() { + if let Some(s) = ty.as_struct() { + if s.data(db).memory_kind == StructMemoryKind::Value { + return false; + } + } + } + true + } } impl HirDisplay for Ty {