Skip to content

Commit

Permalink
feat(code_gen): create marshallable wrapper for unmarshallable functions
Browse files Browse the repository at this point in the history
A function cannot be marshalled when one of its parameters or its return
type are a value struct
  • Loading branch information
Wodann committed Mar 7, 2020
1 parent d539b40 commit 9dc084e
Show file tree
Hide file tree
Showing 10 changed files with 248 additions and 80 deletions.
9 changes: 4 additions & 5 deletions crates/mun_codegen/src/code_gen/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::ir::{
};
use crate::type_info::{TypeGroup, TypeInfo};
use crate::values::{BasicValue, GlobalValue};
use crate::IrDatabase;
use crate::{CodeGenParams, IrDatabase};
use hir::Ty;
use inkwell::{
attributes::Attribute,
Expand Down Expand Up @@ -262,7 +262,6 @@ fn gen_function_info_array<'a, D: IrDatabase>(
functions: impl Iterator<Item = (&'a hir::Function, &'a FunctionValue)>,
) -> GlobalArrayValue {
let function_infos: Vec<StructValue> = functions
.filter(|(f, _)| f.visibility(db) == hir::Visibility::Public)
.map(|(f, value)| {
// Get the function from the cloned module and modify the linkage of the function.
let value = module
Expand Down Expand Up @@ -321,9 +320,9 @@ fn gen_struct_info<D: IrDatabase>(
(0..fields.len()).map(|idx| target_data.offset_of_element(&t, idx as u32).unwrap());
let (field_offsets, _) = gen_u16_array(module, field_offsets);

let field_sizes = fields
.iter()
.map(|field| target_data.get_store_size(&db.type_ir(field.ty(db))));
let field_sizes = fields.iter().map(|field| {
target_data.get_store_size(&db.type_ir(field.ty(db), CodeGenParams { is_extern: false }))
});
let (field_sizes, _) = gen_u16_array(module, field_sizes);

types.struct_info_type.const_named_struct(&[
Expand Down
6 changes: 3 additions & 3 deletions crates/mun_codegen/src/db.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(clippy::type_repetition_in_bounds)]

use crate::{ir::module::ModuleIR, type_info::TypeInfo, Context};
use crate::{ir::module::ModuleIR, type_info::TypeInfo, CodeGenParams, Context};
use inkwell::types::StructType;
use inkwell::{types::AnyTypeEnum, OptimizationLevel};
use mun_target::spec::Target;
Expand All @@ -22,9 +22,9 @@ pub trait IrDatabase: hir::HirDatabase {
#[salsa::input]
fn target(&self) -> Target;

/// Given a type, return the corresponding IR type.
/// Given a type and code generation parameters, return the corresponding IR type.
#[salsa::invoke(crate::ir::ty::ir_query)]
fn type_ir(&self, ty: hir::Ty) -> AnyTypeEnum;
fn type_ir(&self, ty: hir::Ty, params: CodeGenParams) -> AnyTypeEnum;

/// Given a struct, return the corresponding IR type.
#[salsa::invoke(crate::ir::ty::struct_ty_query)]
Expand Down
4 changes: 2 additions & 2 deletions crates/mun_codegen/src/ir/adt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//use crate::ir::module::Types;
use crate::ir::try_convert_any_to_basic;
use crate::IrDatabase;
use crate::{CodeGenParams, IrDatabase};
use inkwell::types::{BasicTypeEnum, StructType};

pub(super) fn gen_struct_decl(db: &impl IrDatabase, s: hir::Struct) -> StructType {
Expand All @@ -11,7 +11,7 @@ pub(super) fn gen_struct_decl(db: &impl IrDatabase, s: hir::Struct) -> StructTyp
.iter()
.map(|field| {
let field_type = field.ty(db);
try_convert_any_to_basic(db.type_ir(field_type))
try_convert_any_to_basic(db.type_ir(field_type, CodeGenParams { is_extern: false }))
.expect("could not convert field type")
})
.collect();
Expand Down
172 changes: 119 additions & 53 deletions crates/mun_codegen/src/ir/body.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use crate::intrinsics;
use crate::{ir::dispatch_table::DispatchTable, ir::try_convert_any_to_basic, IrDatabase};
use crate::{
ir::dispatch_table::DispatchTable, ir::try_convert_any_to_basic, CodeGenParams, IrDatabase,
};
use hir::{
ArenaId, ArithOp, BinaryOp, Body, CmpOp, Expr, ExprId, HirDisplay, InferenceResult, Literal,
Name, Ordering, Pat, PatId, Path, Resolution, Resolver, Statement, TypeCtor,
};
use inkwell::{
builder::Builder,
module::Module,
values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue},
values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue, StructValue},
AddressSpace, FloatPredicate, IntPredicate,
};
use std::{collections::HashMap, mem, sync::Arc};
Expand Down Expand Up @@ -37,6 +39,7 @@ pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> {
dispatch_table: &'b DispatchTable,
active_loop: Option<LoopInfo>,
hir_function: hir::Function,
params: CodeGenParams,
}

impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
Expand All @@ -47,6 +50,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
ir_function: FunctionValue,
function_map: &'a HashMap<hir::Function, FunctionValue>,
dispatch_table: &'b DispatchTable,
params: CodeGenParams,
) -> Self {
// Get the type information from the `hir::Function`
let body = hir_function.body(db);
Expand All @@ -72,6 +76,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
dispatch_table,
active_loop: None,
hir_function,
params,
}
}

Expand Down Expand Up @@ -127,6 +132,50 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
}
}

pub fn gen_fn_wrapper(&mut self) {
let fn_sig = self.hir_function.ty(self.db).callable_sig(self.db).unwrap();
let args: Vec<BasicValueEnum> = fn_sig
.params()
.iter()
.enumerate()
.map(|(idx, ty)| {
let param = self.fn_value.get_nth_param(idx as u32).unwrap();
self.opt_deref_value(ty.clone(), param)
})
.collect();

let ret_value = self
.gen_call(self.hir_function, &args)
.try_as_basic_value()
.left();

let call_return_type = &self.infer[self.body.body_expr()];
if !call_return_type.is_never() {
let fn_ret_type = self
.hir_function
.ty(self.db)
.callable_sig(self.db)
.unwrap()
.ret()
.clone();

if fn_ret_type.is_empty() {
self.builder.build_return(None);
} else if let Some(value) = ret_value {
let ret_value = if let Some(hir_struct) = fn_ret_type.as_struct() {
if hir_struct.data(self.db).memory_kind == hir::StructMemoryKind::Value {
self.gen_struct_alloc_on_heap(hir_struct, value.into_struct_value())
} else {
value
}
} else {
value
};
self.builder.build_return(Some(&ret_value));
}
}
}

/// Generates IR for the specified expression. Dependending on the type of expression an IR
/// value is returned.
fn gen_expr(&mut self, expr: ExprId) -> Option<inkwell::values::BasicValueEnum> {
Expand All @@ -152,6 +201,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
// Get the callable definition from the map
match self.infer[*callee].as_callable_def() {
Some(hir::CallableDef::Function(def)) => {
// Get all the arguments
let args: Vec<BasicValueEnum> = args
.iter()
.map(|expr| self.gen_expr(*expr).expect("expected a value"))
.collect();

self.gen_call(def, &args).try_as_basic_value().left()
}
Some(hir::CallableDef::Struct(_)) => Some(self.gen_named_tuple_lit(expr, args)),
Expand Down Expand Up @@ -235,37 +290,45 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
hir::StructMemoryKind::Value => struct_lit.into(),
hir::StructMemoryKind::GC => {
// TODO: Root memory in GC
let struct_ir_ty = self.db.struct_ty(hir_struct);
let malloc_fn_ptr = self
.dispatch_table
.gen_intrinsic_lookup(&self.builder, &intrinsics::malloc);
let mem_ptr = self
.builder
.build_call(
malloc_fn_ptr,
&[
struct_ir_ty.size_of().unwrap().into(),
struct_ir_ty.get_alignment().into(),
],
"malloc",
)
.try_as_basic_value()
.left()
.unwrap();
let struct_ptr = self
.builder
.build_bitcast(
mem_ptr,
struct_ir_ty.ptr_type(AddressSpace::Generic),
&hir_struct.name(self.db).to_string(),
)
.into_pointer_value();
self.builder.build_store(struct_ptr, struct_lit);
struct_ptr.into()
self.gen_struct_alloc_on_heap(hir_struct, struct_lit)
}
}
}

fn gen_struct_alloc_on_heap(
&mut self,
hir_struct: hir::Struct,
struct_lit: StructValue,
) -> BasicValueEnum {
let struct_ir_ty = self.db.struct_ty(hir_struct);
let malloc_fn_ptr = self
.dispatch_table
.gen_intrinsic_lookup(&self.builder, &intrinsics::malloc);
let mem_ptr = self
.builder
.build_call(
malloc_fn_ptr,
&[
struct_ir_ty.size_of().unwrap().into(),
struct_ir_ty.get_alignment().into(),
],
"malloc",
)
.try_as_basic_value()
.left()
.unwrap();
let struct_ptr = self
.builder
.build_bitcast(
mem_ptr,
struct_ir_ty.ptr_type(AddressSpace::Generic),
&hir_struct.name(self.db).to_string(),
)
.into_pointer_value();
self.builder.build_store(struct_ptr, struct_lit);
struct_ptr.into()
}

/// Generates IR for a record literal, e.g. `Foo { a: 1.23, b: 4 }`
fn gen_record_lit(
&mut self,
Expand Down Expand Up @@ -349,8 +412,11 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
Pat::Bind { name } => {
let builder = self.new_alloca_builder();
let pat_ty = self.infer[pat].clone();
let ty = try_convert_any_to_basic(self.db.type_ir(pat_ty.clone()))
.expect("expected basic type");
let ty = try_convert_any_to_basic(
self.db
.type_ir(pat_ty.clone(), CodeGenParams { is_extern: false }),
)
.expect("expected basic type");
let ptr = builder.build_alloca(ty, &name.to_string());
self.pat_to_local.insert(pat, ptr);
self.pat_to_name.insert(pat, name.to_string());
Expand Down Expand Up @@ -394,16 +460,22 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
}

/// Given an expression and the type of the expression, optionally dereference the value.
fn opt_deref_value(&mut self, expr: ExprId, value: BasicValueEnum) -> BasicValueEnum {
match &self.infer[expr] {
fn opt_deref_value(&mut self, ty: hir::Ty, value: BasicValueEnum) -> BasicValueEnum {
match ty {
hir::Ty::Apply(hir::ApplicationTy {
ctor: hir::TypeCtor::Struct(s),
..
}) => match s.data(self.db).memory_kind {
hir::StructMemoryKind::GC => {
self.builder.build_load(value.into_pointer_value(), "deref")
}
hir::StructMemoryKind::Value => value,
hir::StructMemoryKind::Value => {
if self.params.is_extern {
self.builder.build_load(value.into_pointer_value(), "deref")
} else {
value
}
}
},
_ => value,
}
Expand Down Expand Up @@ -460,12 +532,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
) -> Option<BasicValueEnum> {
let lhs = self
.gen_expr(lhs_expr)
.map(|value| self.opt_deref_value(lhs_expr, value))
.map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value))
.expect("no lhs value")
.into_float_value();
let rhs = self
.gen_expr(rhs_expr)
.map(|value| self.opt_deref_value(rhs_expr, value))
.map(|value| self.opt_deref_value(self.infer[rhs_expr].clone(), value))
.expect("no rhs value")
.into_float_value();
match op {
Expand Down Expand Up @@ -519,12 +591,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
) -> Option<BasicValueEnum> {
let lhs = self
.gen_expr(lhs_expr)
.map(|value| self.opt_deref_value(lhs_expr, value))
.map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value))
.expect("no lhs value")
.into_int_value();
let rhs = self
.gen_expr(rhs_expr)
.map(|value| self.opt_deref_value(lhs_expr, value))
.map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value))
.expect("no rhs value")
.into_int_value();
match op {
Expand Down Expand Up @@ -609,19 +681,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
}
}

// TODO: Implement me!
fn should_use_dispatch_table(&self) -> bool {
true
// FIXME: When we use the dispatch table, generated wrappers have infinite recursion
!self.params.is_extern
}

/// Generates IR for a function call.
fn gen_call(&mut self, function: hir::Function, args: &[ExprId]) -> CallSiteValue {
// Get all the arguments
let args: Vec<BasicValueEnum> = args
.iter()
.map(|expr| self.gen_expr(*expr).expect("expected a value"))
.collect();

fn gen_call(&mut self, function: hir::Function, args: &[BasicValueEnum]) -> CallSiteValue {
if self.should_use_dispatch_table() {
let ptr_value =
self.dispatch_table
Expand Down Expand Up @@ -649,7 +715,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
// Generate IR for the condition
let condition_ir = self
.gen_expr(condition)
.map(|value| self.opt_deref_value(condition, value))?
.map(|value| self.opt_deref_value(self.infer[condition].clone(), value))?
.into_int_value();

// Generate the code blocks to branch to
Expand Down Expand Up @@ -787,7 +853,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
self.builder.position_at_end(&cond_block);
let condition_ir = self
.gen_expr(condition_expr)
.map(|value| self.opt_deref_value(condition_expr, value));
.map(|value| self.opt_deref_value(self.infer[condition_expr].clone(), value));
if let Some(condition_ir) = condition_ir {
self.builder.build_conditional_branch(
condition_ir.into_int_value(),
Expand Down Expand Up @@ -844,11 +910,11 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
}

fn gen_field(&mut self, _expr: ExprId, receiver_expr: ExprId, name: &Name) -> PointerValue {
let receiver_ty = &self.infer[receiver_expr]
let hir_struct = self.infer[receiver_expr]
.as_struct()
.expect("expected a struct");

let field_idx = receiver_ty
let field_idx = hir_struct
.field(self.db, name)
.expect("expected a struct field")
.id()
Expand All @@ -857,13 +923,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {

let receiver_ptr = self.gen_place_expr(receiver_expr);
let receiver_ptr = self
.opt_deref_value(receiver_expr, receiver_ptr.into())
.opt_deref_value(self.infer[receiver_expr].clone(), receiver_ptr.into())
.into_pointer_value();
unsafe {
self.builder.build_struct_gep(
receiver_ptr,
field_idx,
&format!("{}.{}", receiver_ty.name(self.db), name),
&format!("{}.{}", hir_struct.name(self.db), name),
)
}
}
Expand Down
Loading

0 comments on commit 9dc084e

Please sign in to comment.