__multi3 optimization #4077

Open
shamatar opened this issue Apr 27, 2022 · 19 comments · May be fixed by #8653
Labels
cranelift Issues related to the Cranelift code generator enhancement


@shamatar

Feature

Implement an optimization pass that eliminates the __multi3 function from a WASM binary during JIT compilation by replacing it with ISA-specific sequences (mainly for x86_64 and arm64), and then inlines those sequences into call sites, which would allow further optimizations.

Benefit

A lot of cryptographic code would benefit from faster full-width u64 multiplications, which is where __multi3 arises.

Implementation

If someone could give a few hints about where to start, I'd try to implement it myself.

Alternatives

Not that I'm aware of. Patching in a call to some native library function is a huge overhead on modern CPUs (the multiply itself, e.g. mul or mulx on x86_64, takes about 4 cycles), and while it would most likely be faster, it's still far from optimal on a hot path.

As an example, a simple multiply-add-carry function like a*b + c + carry -> (high, low), which accumulates into a u128 without overflow, compiles down to the listing below. It can be a good test subject (converted from wasm to wat, so it may not be the most readable).

(module
  (type (;0;) (func (param i32 i64 i64 i64 i64)))
  (func $mac (type 0) (param i32 i64 i64 i64 i64)
    (local i32)
    global.get $__stack_pointer
    i32.const 16
    i32.sub
    local.tee 5
    global.set $__stack_pointer
    local.get 5
    local.get 2
    i64.const 0
    local.get 1
    i64.const 0
    call $__multi3
    local.get 0
    local.get 5
    i64.load
    local.tee 2
    local.get 3
    i64.add
    local.tee 3
    local.get 4
    i64.add
    local.tee 4
    i64.store
    local.get 0
    local.get 5
    i32.const 8
    i32.add
    i64.load
    local.get 3
    local.get 2
    i64.lt_u
    i64.extend_i32_u
    i64.add
    local.get 4
    local.get 3
    i64.lt_u
    i64.extend_i32_u
    i64.add
    i64.store offset=8
    local.get 5
    i32.const 16
    i32.add
    global.set $__stack_pointer
  )
  (func $__multi3 (type 0) (param i32 i64 i64 i64 i64)
    (local i64 i64 i64 i64 i64 i64)
    local.get 0
    local.get 3
    i64.const 4294967295
    i64.and
    local.tee 5
    local.get 1
    i64.const 4294967295
    i64.and
    local.tee 6
    i64.mul
    local.tee 7
    local.get 5
    local.get 1
    i64.const 32
    i64.shr_u
    local.tee 8
    i64.mul
    local.tee 9
    local.get 3
    i64.const 32
    i64.shr_u
    local.tee 10
    local.get 6
    i64.mul
    i64.add
    local.tee 5
    i64.const 32
    i64.shl
    i64.add
    local.tee 6
    i64.store
    local.get 0
    local.get 10
    local.get 8
    i64.mul
    local.get 5
    local.get 9
    i64.lt_u
    i64.extend_i32_u
    i64.const 32
    i64.shl
    local.get 5
    i64.const 32
    i64.shr_u
    i64.or
    i64.add
    local.get 6
    local.get 7
    i64.lt_u
    i64.extend_i32_u
    i64.add
    local.get 4
    local.get 1
    i64.mul
    local.get 3
    local.get 2
    i64.mul
    i64.add
    i64.add
    i64.store offset=8
  )
  (table (;0;) 1 1 funcref)
  (memory (;0;) 16)
  (global $__stack_pointer (mut i32) i32.const 1048576)
  (global (;1;) i32 i32.const 1048576)
  (global (;2;) i32 i32.const 1048576)
  (export "memory" (memory 0))
  (export "mac" (func $mac))
  (export "__data_end" (global 1))
  (export "__heap_base" (global 2))
)
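For readers less used to WAT, here is a rough Rust transcription of the `$mac` body above. It is only a sketch: it assumes the `__multi3` call can be modeled as a 64x64 -> 128 widening multiply (both high halves are the constant 0 in this call), it uses tuples instead of the stack slot and output pointer, and the helper name `multi3_zero_high` is hypothetical. The two `i64.lt_u` comparisons are the carry checks for the two 64-bit additions.

```rust
/// Hypothetical stand-in for the `__multi3` call in `$mac`: both high halves
/// are the constant 0 there, so it reduces to a 64x64 -> 128 widening multiply.
fn multi3_zero_high(x: u64, y: u64) -> (u64, u64) {
    let p = (x as u128) * (y as u128);
    (p as u64, (p >> 64) as u64)
}

/// Transcription of the `$mac` body: returns (low, high) of a*b + c + carry.
fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) {
    // `call $__multi3` with arguments (b, 0, a, 0); the result lands in the stack slot.
    let (lo, hi) = multi3_zero_high(b, a);
    let sum1 = lo.wrapping_add(c); // local 3
    let sum2 = sum1.wrapping_add(carry); // local 4, stored at offset 0
    // `i64.lt_u` + `i64.extend_i32_u`: a wrapped unsigned add is smaller than its first operand.
    let c1 = (sum1 < lo) as u64;
    let c2 = (sum2 < sum1) as u64;
    // Stored at offset 8; `i64.add` wraps, but the bound on a*b + c + carry means it never does here.
    let high = hi.wrapping_add(c1).wrapping_add(c2);
    (sum2, high)
}

fn main() {
    let m = u64::MAX;
    for &(a, b, c, k) in &[(0, 0, 0, 0), (3, 5, 7, 11), (m, m, m, m), (m, 2, m, 1)] {
        let expected = (a as u128) * (b as u128) + (c as u128) + (k as u128);
        let (low, high) = mac(a, b, c, k);
        assert_eq!(((high as u128) << 64) | low as u128, expected);
    }
    println!("ok");
}
```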
@cfallin
Member

cfallin commented Apr 27, 2022

Hi @shamatar -- thanks for raising this issue! I agree that the lack of "architecture-specific acceleration" for wide arithmetic lowered into Wasm bytecode is suboptimal and annoying.

On the other hand, somehow recognizing that a Wasm module's internal function happens to be the __multi3 function from the Wasm toolchain's runtime, and then actually replacing that function with another function, seems very likely to me to break the abstraction boundary (and hence security and portability properties) that the Wasm runtime provides. Either we check that it is exactly the __multi3 that we expect (same Wasm bytecode body), in which case this is a very brittle optimization (breaks with any little Wasm-toolchain implementation detail change), or we don't, in which case we might replace the code with an implementation that does something different (incorrectly). So, for certain, we would not implement logic that somehow recognizes __multi3 by name as special.

I think that there is a way around this though: we could pattern-match the actual operations and use the faster architecture-specific lowering when possible. In this specific case, a pattern could match whatever subgraphs correspond to the fast x64 instructions (mulx and friends?). I don't know enough about the state of the art of computer arithmetic implementations on x64 or aarch64 to suggest specific mappings.

The place to put such lowerings would be in the ISLE DSL (see e.g. cranelift/codegen/src/isa/x64/lower.isle); we'd be happy to help if you want to try to implement something here.

@shamatar
Author

Hey @cfallin

Thank you for the detailed response. I only mentioned replacing __multi3 by name because I had seen, in one of the old threads, an idea to call an external library for it as a speedup. But if name matching is potentially too fragile (it might be safe enough to allow under a feature flag, if such a flag existed :) ), I'll try to implement matching by logic. The __multi3 body looks like the standard approach of splitting each u64 into u32 halves and simulating the wide multiply with u32 * u32 -> u64 products, so hopefully it can be matched.

As for the "state of the art": the best one can get in the generic case is, I believe, the adx + mulx combination, but LLVM, for example, never emits mulx, so I'd start by trying to optimize into add + mul for x86_64.

I'm not sure that the file you mentioned is the right place, since it already has a rule for full-width multiplication lowering:

(rule (lower (has_type (fits_in_64 ty) (imul x y)))
but the optimization first needs to recognize __multi3, emit just imul x y for it, and inline it into the call site. Maybe this also requires more than one pass, but the first step is to get to just imul x y.

@cfallin
Member

cfallin commented Apr 27, 2022

catch __multi3, and emit just imul x y for it + inline into callsite.

That's true: one way this could work in the future is to recognize the whole function and replace it with an imul.i128 (if that's what you mean by "catch __multi3, and emit just imul"?), but I think that would hit the brittleness issues I mentioned.

Maybe the right way to start is to draw the dataflow graph of the bytecode you listed above, and find subgraphs that correspond to what adx and mulx do? If we can write a pattern that matches each of those, maybe we can get something at least slightly better than what we have today.

Eventually we will also have an inliner (it's really an orthogonal improvement) and when we do, it could combine with good-enough pattern matching on the function body here to produce an inlined sequence of instructions. In other words, getting to a tight sequence of instructions, and inlining that at the callsite, are two separate problems so let's solve each separately.

cc @abrown and @jlb6740 for thoughts on the x86-64 lowering here...

@shamatar
Author

"to recognize the whole function and replace it": yes, that's the final goal, and ideally the function body will be pattern-matched somehow (and not by name). My impression on lower.isle file is that it's a definition of "simple" rules to lower some short sequences of CLIF instructions to machine instructions, and not the rules for "large" pattern matching.

I'll make a dataflow graph; I just need some time.

@shamatar
Copy link
Author

More precisely, the rule

;; `i64` and smaller.

;; Multiply two registers.
(rule (lower (has_type (fits_in_64 ty) (imul x y)))
      (x64_mul ty x y))

would match u64 * u64 multiplication, but I'm also not sure whether this rule produces only the low bits (half-width mul) or all of them (full-width), as it says nothing about the return type.

__multi3 does contain i64.mul, but the inputs to that i64.mul are 32-bit values, so it is naturally a full-width multiplication. And I'd like to "guess" that the larger piece of code is indeed a full-width u64 multiplication, which just doesn't have a WASM instruction :)
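The bound behind this observation: the largest product of two 32-bit values still fits in 64 bits, so a wrapping 64-bit multiply (which is what `i64.mul`, and CLIF `imul` on `i64`, computes) loses nothing when both inputs are zero-extended 32-bit values:

```latex
\[
(2^{32}-1)\cdot(2^{32}-1) \;=\; 2^{64} - 2^{33} + 1 \;<\; 2^{64}
\]
```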

@abrown
Copy link
Contributor

abrown commented Apr 28, 2022

(Catching up on this thread...) For context, the code behind the WAT that you posted is probably coming from somewhere like this in LLVM: https://github.com/llvm-mirror/compiler-rt/blob/master/lib/builtins/multi3.c. There is quite a bit of shifting and masking there which, as we can see in the WAT, get translated directly to WebAssembly.

I'd be interested to understand what library is being used and how you're using it (e.g., how come __multi3 is on the hot path?). Here's why:
a. it could be that the library could be tweaked for better compilation to Wasm (e.g., should the algorithm use SIMD instead?)
b. if it is a widely-used library, then it motivates improving this sequence in Wasmtime (or even in Wasm itself)

To optimize this in Wasmtime, I suspect the bottom-up approach (i.e., attempting to get rid of the extra shifting and masking) is going to be "good enough" relative to the code sequence currently emitted by Clang in native code (@cfallin has that) and will be less brittle than trying to match the entire function or the called name. But another option might be to attempt to optimize this at the LLVM layer: perhaps it does not need to be as literal about how it translates the C code I linked to to the WAT @shamatar posted.

@shamatar
Copy link
Author

shamatar commented Apr 28, 2022

I've made a minimal example of the typical u64 * u64 + u64 + u64 operation, which cannot overflow a u128 because the largest possible result is (2^64 - 1) * (2^64 - 1) + 2 * (2^64 - 1) = 2^128 - 1.
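Expanding that bound term by term confirms it is exactly the largest u128 value:

```latex
\[
(2^{64}-1)^2 + 2\,(2^{64}-1)
  = \bigl(2^{128} - 2^{65} + 1\bigr) + \bigl(2^{65} - 2\bigr)
  = 2^{128} - 1
\]
```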

#![feature(bigint_helper_methods)]

#[repr(C)]
pub struct U128Pair {
    pub low: u64,
    pub high: u64
}

// #[no_mangle]
// pub extern "C" fn mac(a: u64, b: u64, c: u64, carry: u64) -> U128Pair {
//     let result = (a as u128) * (b as u128) + (c as u128) + (carry as u128);

//     U128Pair {
//         low: result as u64,
//         high: (result >> 64) as u64
//     }
// }

#[no_mangle]
pub extern "C" fn mac(a: u64, b: u64, c: u64, carry: u64) -> U128Pair {
    let (low, high) = a.carrying_mul(b, c);
    let (low, of) = low.overflowing_add(carry);
    let high = high.wrapping_add(of as u64);
    
    U128Pair {
        low,
        high,
    }
}

Both the commented-out code and the version using more explicit operations produce the same assembly. This is the typical way to do big-integer math on a CPU. SIMD optimizations are possible, but they are a separate field because SIMD lacks carries and widening multiplies, and because of CPU frequency quirks when SIMD is used, e.g. on Intel. So my example is an "average case", and solving exactly this part would be the most beneficial (I don't actually need this speedup for any production code, but it's a nice free-time task).

So the compiler inserts __multi3 there as a half-width u128 multiplication. Even though it's the same function, it isn't used at "full power", since the high parts of the arguments are constant zeroes. I think it's possible to match the internals of __multi3, but it takes more than just eliminating shifts: the whole thing must be replaced by a single machine instruction plus stack manipulation.

export function mac(a:{ a:long, b:long }, b:long, c:long, d:long, e:long) {
  var f:long_ptr = stack_pointer - 16;
  stack_pointer = f;
  multi3(f, c, 0L, b, 0L);
  a.a = (e = (d = (c = f[0]) + d) + e);
  a.b =
    i64_extend_i32_u(e < d) + (f + 8)[0]:long + i64_extend_i32_u(d < c);
  stack_pointer = f + 16;
}

function multi3(a:{ a:long, b:long }, b:long, c:long, d:long, e:long) {
  var f:long;
  var g:long;
  var h:long;
  var i:long;
  var j:long;
  var k:long;
  a.a =
    (g = (h = (f = d & 4294967295L) * (g = b & 4294967295L)) +
         ((f = (j = f * (i = b >> 32L)) + (k = d >> 32L) * g) << 32L));
  a.b =
    k * i + (i64_extend_i32_u(f < j) << 32L | f >> 32L) +
    i64_extend_i32_u(g < h) +
    e * b + d * c;
}

Note: it's only possible to replace the way __multi3 is used here, not in general, since here it is really a full-width u64 multiplication expressed as a half-width u128 multiplication whose high parts are constant zero: multi3(f, c, 0L, b, 0L);

@shamatar
Copy link
Author

I should add that even though it may be possible to fine-tune LLVM too (which is even further from my world), it's still not possible to generate WASM code that would be better than the internals of __multi3 (we can only remove the e * b + d * c subexpression, since it consists of trivial multiplications by 0), so the problem would just move from matching the current __multi3 internals to matching some other, similar function body.

@cfallin
Copy link
Member

cfallin commented Apr 28, 2022

My impression on lower.isle file is that it's a definition of "simple" rules to lower some short sequences of CLIF instructions to machine instructions, and not the rules for "large" pattern matching.

It's potentially both! We're still in the midst of translating all of the basic lowering into ISLE, so that is most of what we have. But the design goal is absolutely to allow arbitrarily complex patterns, and we hope to grow the library of patterns we match and optimize over time.

I took a look at the gcc output on x86-64 for an __int128 * __int128 case (in C) and I saw:

	movq	%rdx, %r8
	movq	%rdx, %rax
	mulq	%rdi
	imulq	%r8, %rsi
	addq	%rsi, %rdx
	imulq	%rdi, %rcx
	addq	%rcx, %rdx

so it should be possible to do a lot better here, and without recent extensions like mulx/adx.
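As a sketch of what that gcc sequence computes (assuming the System V convention, with each `__int128` operand split into low and high 64-bit halves): `mulq` forms the full 128-bit product of the two low halves, and the two `imulq`/`addq` pairs add the low 64 bits of the cross products into the high half. The function name `mul128` is hypothetical.

```rust
/// Sketch of a 128x128 -> 128 wrapping multiply on (lo, hi) halves, following the
/// shape of the gcc listing: one widening multiply plus two wrapping multiplies.
fn mul128(a_lo: u64, a_hi: u64, b_lo: u64, b_hi: u64) -> (u64, u64) {
    // `mulq`: full 64x64 -> 128 product of the low halves.
    let p = (a_lo as u128) * (b_lo as u128);
    let lo = p as u64;
    let mut hi = (p >> 64) as u64;
    // Two `imulq` + `addq`: only the low 64 bits of the cross terms reach the result.
    hi = hi.wrapping_add(a_hi.wrapping_mul(b_lo));
    hi = hi.wrapping_add(a_lo.wrapping_mul(b_hi));
    (lo, hi)
}

fn main() {
    let cases = [
        (0u64, 0u64, 0u64, 0u64),
        (u64::MAX, 1, 2, 3),
        (u64::MAX, u64::MAX, u64::MAX, u64::MAX),
        (0x1234_5678_9abc_def0, 0, 0x0fed_cba9_8765_4321, 0),
    ];
    for &(a_lo, a_hi, b_lo, b_hi) in &cases {
        let a = ((a_hi as u128) << 64) | a_lo as u128;
        let b = ((b_hi as u128) << 64) | b_lo as u128;
        let (lo, hi) = mul128(a_lo, a_hi, b_lo, b_hi);
        assert_eq!(((hi as u128) << 64) | lo as u128, a.wrapping_mul(b));
    }
    println!("ok");
}
```

When both high halves are 0, as in the `__multi3` call from `mac`, the two cross terms vanish and only the widening multiply of the low halves is left.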

@shamatar
Copy link
Author

so it should be possible to do a lot better here, and without recent extensions like mulx/adx.

Yes, it's possible to some extent because of the two carry chains (in my example of u64 * u64 + u64 + u64), but that would be the next step (and adx is much more useful for even "larger" math, like full-width u256 multiplication).

So the "optimal" code for my example looks like this:

mac:
        mov     r8, rdx
        mov     rax, rsi
        mul     rdi
        add     rax, r8
        adc     rdx, 0
        add     rax, rcx
        adc     rdx, 0
        ret

where adx would allow speeding up the two carry-propagation chains. mulx allows more flexible register allocation, but as far as I know none of the compilers use it in practice.

@sparker-arm
Copy link
Contributor

Since Cranelift supports i128, maybe we could perform CLIF-to-CLIF transforms to generate an i128 mul (if a backend wants that). That way we can avoid each backend having to add complicated matching patterns. Plus, at least for aarch64, there's already a lot of existing support for i128.

@shamatar
Copy link
Author

shamatar commented May 4, 2022

I'm still trying to understand pattern matching well enough to capture the full body of __multi3 (even if only fragilely). Then it can be replaced by an i128 half-width multiplication or something else. After that I'd have to take another large step to inline it and also constant-fold it (since in practice it's not an i128 half-width mul but an i64 full-width mul), which is most likely even more complex.

@shamatar
Copy link
Author

shamatar commented May 6, 2022

I've tried to insert an optimization into Cranelift directly. I should be able to match the full body of the generated __multi3 and transform it into an i128 multiplication. Please check #4106.

@shamatar
Copy link
Author

shamatar commented May 7, 2022

I've made an initial functional PR, which folds

block0(v0: i64, v1: i64, v2: i32, v3: i64, v4: i64, v5: i64, v6: i64):
@5130                               v7 = iconst.i64 0
@5136                               v8 = iconst.i64 0xffff_ffff
@513c                               v9 = band_imm v5, 0xffff_ffff
@5141                               v10 = iconst.i64 0xffff_ffff
@5147                               v11 = band_imm v3, 0xffff_ffff
@514a                               v12 = imul v9, v11
@5151                               v13 = iconst.i64 32
@5153                               v14 = ushr_imm v3, 32
@5156                               v15 = imul v9, v14
@515b                               v16 = iconst.i64 32
@515d                               v17 = ushr_imm v5, 32
@5162                               v18 = imul v17, v11
@5163                               v19 = iadd v15, v18
@5166                               v20 = iconst.i64 32
@5168                               v21 = ishl_imm v19, 32
@5169                               v22 = iadd v12, v21
@516c                               v23 = heap_addr.i64 heap0, v2, 1
@516c                               store little v22, v23
@5175                               v24 = imul v17, v14
@517a                               v25 = icmp ult v19, v15
@517a                               v26 = bint.i32 v25
@517b                               v27 = uextend.i64 v26
@517c                               v28 = iconst.i64 32
@517e                               v29 = ishl_imm v27, 32
@5181                               v30 = iconst.i64 32
@5183                               v31 = ushr_imm v19, 32
@5184                               v32 = bor v29, v31
@5185                               v33 = iadd v24, v32
@518a                               v34 = icmp ult v22, v12
@518a                               v35 = bint.i32 v34
@518b                               v36 = uextend.i64 v35
@518c                               v37 = iadd v33, v36
@5191                               v38 = imul v6, v3
@5196                               v39 = imul v5, v4
@5197                               v40 = iadd v38, v39
@5198                               v41 = iadd v37, v40
@5199                               v42 = heap_addr.i64 heap0, v2, 1
@5199                               store little v41, v42+8
@519c                               jump block1

                                block1:
@519c                               return

into

block0(v0: i64, v1: i64, v2: i32, v3: i64, v4: i64, v5: i64, v6: i64):
                                    v43 = iconcat v3, v4
                                    v44 = iconcat v5, v6
                                    v45 = imul v43, v44
                                    v46, v47 = isplit v45
@5130                               v7 = iconst.i64 0
@5136                               v8 = iconst.i64 0xffff_ffff
@513c                               v9 = nop 
@5141                               v10 = iconst.i64 0xffff_ffff
@5147                               v11 = nop 
@514a                               v12 = nop 
@5151                               v13 = iconst.i64 32
@5153                               v14 = nop 
@5156                               v15 = nop 
@515b                               v16 = iconst.i64 32
@515d                               v17 = nop 
@5162                               v18 = nop 
@5163                               v19 = nop 
@5166                               v20 = iconst.i64 32
@5168                               v21 = nop 
@5169                               v22 = nop 
@516c                               v23 = heap_addr.i64 heap0, v2, 1
@516c                               store little v46, v23
@5175                               v24 = nop 
@517a                               v25 = nop 
@517a                               v26 = nop 
@517b                               v27 = nop 
@517c                               v28 = iconst.i64 32
@517e                               v29 = nop 
@5181                               v30 = iconst.i64 32
@5183                               v31 = nop 
@5184                               v32 = nop 
@5185                               v33 = nop 
@518a                               v34 = nop 
@518a                               v35 = nop 
@518b                               v36 = nop 
@518c                               v37 = nop 
@5191                               v38 = nop 
@5196                               v39 = nop 
@5197                               v40 = nop 
@5198                               v41 = nop 
@5199                               v42 = heap_addr.i64 heap0, v2, 1
@5199                               store little v47, v42+8
@519c                               jump block1

                                block1:
@519c                               return
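The folded body can be read as: glue each (low, high) pair into one 128-bit value (`iconcat`), multiply with wrapping semantics (`imul` on `i128`), and split the product back into halves (`isplit`). The Rust functions below are hypothetical models of those CLIF instructions for illustration, not Cranelift APIs.

```rust
/// Model of CLIF `iconcat`: glue a low and a high half into one 128-bit value.
fn iconcat(lo: u64, hi: u64) -> u128 {
    ((hi as u128) << 64) | lo as u128
}

/// Model of CLIF `isplit`: split a 128-bit value into (low, high) halves.
fn isplit(x: u128) -> (u64, u64) {
    (x as u64, (x >> 64) as u64)
}

/// The folded body: v46, v47 = isplit(imul(iconcat(v3, v4), iconcat(v5, v6))).
fn folded_multi3(v3: u64, v4: u64, v5: u64, v6: u64) -> (u64, u64) {
    // `imul` on i128 keeps the low 128 bits of the product, like `wrapping_mul`.
    isplit(iconcat(v3, v4).wrapping_mul(iconcat(v5, v6)))
}

fn main() {
    assert_eq!(isplit(iconcat(7, 9)), (7, 9));
    // With zero high halves this is a 64x64 -> 128 widening multiply.
    assert_eq!(folded_multi3(u64::MAX, 0, u64::MAX, 0), (1, u64::MAX - 1));
    println!("ok");
}
```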

A very naive benchmark shows around a 30% speedup.

@shamatar
Copy link
Author

shamatar commented May 7, 2022

A few comments:

  • is there any pass that removes unused values? I believe it should be, but I'd like to trigger it manually
  • it may be a better option to try a manual implementation using imul/umulhi instead of concat - imul - split
  • is there any inlining step? In my benchmarks I tried both making the main MAC arithmetic function "extern C" and making it a normal Rust function with the "inline" attribute, and there was no difference, which most likely indicates that "__multi3" is inserted by the compiler backend at a later stage. Once I substantially reduce the body of "__multi3", it should be a good candidate for inlining, and inlining would also give me more information to work with, like the fact that parts of the input are always 0

@bjorn3
Copy link
Contributor

bjorn3 commented May 7, 2022

is there any pass that removes unused values? I believe it should be, but I'd like to trigger it manually

Yes, the DCE pass. Note that lowering to the backend-specific IR already does this implicitly. Also, I think you shouldn't replace the instructions with nops. That results in invalid CLIF IR (as nop doesn't return any values), and it is incorrect if any of the instructions' return values are used elsewhere, for example because you matched a function that looks somewhat like __multi3 but not exactly.

is there any inlining step?

No, there isn't. It is also non-trivial to implement as currently every function is independently lowered to clif ir, optimized and codegened to native code. Wasmtime even compiles multiple functions in parallel.

@shamatar
Copy link
Author

shamatar commented May 7, 2022

Thanks @bjorn3

I've analyzed my NOPing approach and it's overzealous indeed.

As for inlining: are there any plans to add it? It would require some form of "synchronization point" after CLIF generation; then inlining itself could also be parallelized, followed by parallel native codegen again.

@bjorn3
Copy link
Contributor

bjorn3 commented May 7, 2022

@cfallin cfallin added enhancement cranelift Issues related to the Cranelift code generator labels May 10, 2022
@jameysharp
Copy link
Contributor

I've been investigating what we can do to improve the performance of __multi3. For reference, the C source for this function is in LLVM at compiler-rt/lib/builtins/multi3.c, which helps a little with understanding what's going on.

Part of the solution proposed in earlier discussion here is function inlining. @elliottt and I discussed this some yesterday and what I'd like to see is a separate tool that supports an inlining transformation on core WebAssembly. I'd then recommend that people use that to eliminate the optimization barriers around __multi3. I don't see inlining actually happening soon in Wasmtime itself and it'd be nice to have something to recommend for that.

One difficult problem is that this function returns a two-member struct, which according to the WebAssembly basic C ABI means it must be returned through linear memory. If it were returned on the wasm stack we'd be able to keep both parts in registers but instead we have to write to RAM. There is no optimization we can legally do to avoid the writes, even with inlining, although with inlining we likely can avoid reading the values back from RAM. Ideally the C ABI would be updated to take advantage of multi-value returns, which apparently didn't exist when it was specified.

Now, what is this function actually computing, and what patterns can we recognize and transform in the mid-end?

Raw CLIF for __multi3, slightly edited
test optimize precise-output
set opt_level=speed_and_size
target x86_64

; v3  = a_lo
; v4  = a_hi
; v5  = b_lo
; v6  = b_hi
; v11 = a0
; v14 = a1
; v9  = b0
; v17 = b1
function %multi3(i64 vmctx, i64, i32, i64, i64, i64, i64) fast {
    gv0 = vmctx
    gv1 = load.i64 notrap aligned readonly gv0+80

block0(v0: i64, v1: i64, v2: i32, v3: i64, v4: i64, v5: i64, v6: i64):
    v7 = iconst.i64 0
    v8 = iconst.i64 0xffff_ffff
    v9 = band v5, v8
    v10 = iconst.i64 0xffff_ffff
    v11 = band v3, v10
    v12 = imul v9, v11
    v13 = iconst.i64 32
    v14 = ushr v3, v13
    v15 = imul v9, v14
    v16 = iconst.i64 32
    v17 = ushr v5, v16
    v18 = imul v17, v11
    v19 = iadd v15, v18
    v20 = iconst.i64 32
    v21 = ishl v19, v20
    v22 = iadd v12, v21
    v23 = uextend.i64 v2
    v24 = global_value.i64 gv1
    v25 = iadd v24, v23
    store little heap v22, v25
    v26 = imul v17, v14
    v27 = icmp ult v19, v15
    v28 = uextend.i32 v27
    v29 = uextend.i64 v28
    v30 = iconst.i64 32
    v31 = ishl v29, v30
    v32 = iconst.i64 32
    v33 = ushr v19, v32
    v34 = bor v31, v33
    v35 = iadd v26, v34
    v36 = icmp ult v22, v12
    v37 = uextend.i32 v36
    v38 = uextend.i64 v37
    v39 = iadd v35, v38
    v40 = imul v6, v3
    v41 = imul v5, v4
    v42 = iadd v40, v41
    v43 = iadd v39, v42
    v44 = uextend.i64 v2
    v45 = global_value.i64 gv1
    v46 = iadd v45, v44
    v47 = iadd_imm v46, 8
    store little heap v43, v47
    return
}

This function implements a 128-bit multiply where the operands and result are stored as pairs of 64-bit integers. It builds up the result, in part, using a series of 32x32->64 multiplies. Viewed as a sequence of 32-bit "digits", the result of long multiplication should look like this, although each single-digit product may have a carry into the column to its left:

     a3    a2    a1    a0
   * b3    b2    b1    b0
  =======================
  a0*b3 a0*b2 a0*b1 a0*b0
+ a1*b2 a1*b1 a1*b0
+ a2*b1 a2*b0
+ a3*b0

However only a1/a0 and b1/b0 are actually treated this way, to compute the lower half of the result along with part of the upper half.

                 a1    a0
   *             b1    b0
  =======================
              a0*b1 a0*b0
+       a1*b1 a1*b0

In Cranelift, we'd want to implement this with a pair of instructions: imul and umulhi, to produce the lower and upper 64 bits, respectively.

The remaining part of the upper half is just a_lo*b_hi+a_hi*b_lo, performed as regular 64-bit multiplies which are equivalent to these parts of the 32-bit-at-a-time long multiplication:

  a0*b3 a0*b2
+ a1*b2
+ a2*b1 a2*b0
+ a3*b0

There's actually nothing we can improve in this part, as shown in @cfallin's gcc-generated assembly listing, which has two imulq and two addq.

So back to the lower-half 64x64->128 multiply, performed 32 bits at a time. Our current optimizer produces this sequence which is equivalent to v22 = imul.i64 v3, v5:

    v8 = iconst.i64 0xffff_ffff
    v9 = band v5, v8  ; v8 = 0xffff_ffff
    v11 = band v3, v8  ; v8 = 0xffff_ffff
    v12 = imul v9, v11
    v13 = iconst.i64 32
    v14 = ushr v3, v13  ; v13 = 32
    v15 = imul v9, v14
    v17 = ushr v5, v13  ; v13 = 32
    v18 = imul v17, v11
    v19 = iadd v15, v18
    v21 = ishl v19, v13  ; v13 = 32
    v22 = iadd v12, v21
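One way to convince oneself of this equivalence is a direct transcription into Rust, using wrapping arithmetic where the CLIF `iadd` may wrap, compared against a plain wrapping 64-bit multiply. The variable names mirror the CLIF values; `lower_half` is a hypothetical name.

```rust
/// Transcription of the lower-half CLIF sequence; names follow the CLIF values.
fn lower_half(v3: u64, v5: u64) -> u64 {
    let v8 = 0xffff_ffffu64;
    let v9 = v5 & v8; // b0
    let v11 = v3 & v8; // a0
    let v12 = v9 * v11; // b0*a0, exact: a 32x32 product fits in 64 bits
    let v14 = v3 >> 32; // a1
    let v15 = v9 * v14; // b0*a1
    let v17 = v5 >> 32; // b1
    let v18 = v17 * v11; // b1*a0
    let v19 = v15.wrapping_add(v18); // middle column, may wrap
    let v21 = v19 << 32; // bits shifted past bit 63 are dropped, as with CLIF `ishl`
    v12.wrapping_add(v21)
}

fn main() {
    let values = [0u64, 1, 0xffff_ffff, 0x1_0000_0000, 0xdead_beef_cafe_f00d, u64::MAX];
    for &a in &values {
        for &b in &values {
            assert_eq!(lower_half(a, b), a.wrapping_mul(b));
        }
    }
    println!("ok");
}
```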

Then I think this part (which reuses some of the intermediate values above) is equivalent to either umulhi or smulhi. I haven't yet figured out what purpose the icmp instructions serve, although I can see they're effectively checking if the adds overflowed for v22 and v19 above.

    v26 = imul v17, v14
    v27 = icmp ult v19, v15
    v49 = uextend.i64 v27
    v31 = ishl v49, v13  ; v13 = 32
    v33 = ushr v19, v13  ; v13 = 32
    v34 = bor v31, v33
    v35 = iadd v26, v34
    v36 = icmp ult v22, v12
    v51 = uextend.i64 v36
    v39 = iadd v35, v51
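For what it's worth, `icmp ult` after an add is the standard unsigned-overflow test: a wrapping sum `x + y` wrapped exactly when the result is less than `x`. Here `v27` recovers the carry out of the middle column `v19` (worth 2^96, hence the shift left by 32 within the upper word) and `v36` the carry out of the low word `v22`. The sketch below transcribes the sequence (recomputing the lower-half values it reuses) and compares it with the unsigned high half of the 128-bit product, i.e. what `umulhi` computes; inputs with the top bit set, such as `u64::MAX`, would distinguish it from `smulhi`. Function names are hypothetical.

```rust
/// Transcription of the upper-half CLIF sequence, plus the lower-half values it reuses.
fn upper_half(v3: u64, v5: u64) -> u64 {
    let v9 = v5 & 0xffff_ffff;
    let v11 = v3 & 0xffff_ffff;
    let v12 = v9 * v11;
    let v14 = v3 >> 32;
    let v15 = v9 * v14;
    let v17 = v5 >> 32;
    let v19 = v15.wrapping_add(v17 * v11);
    let v22 = v12.wrapping_add(v19 << 32);

    let v26 = v17 * v14; // b1*a1, the column worth 2^64
    // `icmp ult`: the middle-column add wrapped iff the sum is below an addend.
    let v49 = (v19 < v15) as u64;
    // Carry bit and top half of the middle column; the bits never overlap.
    let v34 = (v49 << 32) | (v19 >> 32);
    let v35 = v26.wrapping_add(v34);
    // Carry out of the low word.
    let v51 = (v22 < v12) as u64;
    v35.wrapping_add(v51)
}

/// Unsigned high half of the 128-bit product (the semantics of CLIF `umulhi`).
fn umulhi(a: u64, b: u64) -> u64 {
    (((a as u128) * (b as u128)) >> 64) as u64
}

fn main() {
    let values = [0u64, 1, 0xffff_ffff, 0x1_0000_0000, 0xdead_beef_cafe_f00d, u64::MAX];
    for &a in &values {
        for &b in &values {
            assert_eq!(upper_half(a, b), umulhi(a, b));
        }
    }
    println!("ok");
}
```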

If we can get egraph rules to simplify these two sequences to the upper and lower halves of a 64-bit multiply, then once we also solve #5623 we'll get the same sequence of instructions out that gcc produces for 128-bit multiplies.

Alternatively, if we could match on multiple results at once then we could write a rule that matches this combination of three imul, one umulhi, and three iadd and turns them into a single 128-bit imul surrounded by iconcat and isplit. I believe on x64 we already lower that to something equivalent to the gcc-generated sequence.

jameysharp added a commit to jameysharp/wasmtime that referenced this issue May 19, 2024
LLVM's `__multi3` function works by splitting a wide multiplication into
several narrower ones. This optimization recognizes the algebraic
identities involved and merges them back into the original wide
multiply.

This is not yet done but illustrates how part of the optimization can
work, at least.

Currently, the lower half of the result is optimized into a single
`imul` instruction, but most of the intermediate values that are
optimized away there are still used in computing the upper half, so
elaboration brings them back later.

Fixes bytecodealliance#4077
@jameysharp jameysharp linked a pull request May 19, 2024 that will close this issue