Skip to content

Commit

Permalink
Perform bounds checks on bulk memory operations
Browse files Browse the repository at this point in the history
Now, all memory operations are guarded by bounds checks and don't need
to be guarded at the call site.

Out-of-bounds memory accesses no longer invalidate compiled code.
  • Loading branch information
jirkamarsik committed Jan 30, 2025
1 parent 4136504 commit f2067db
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 159 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* The Universal Permissive License (UPL), Version 1.0
Expand Down Expand Up @@ -696,7 +696,7 @@ public void testMemoryInitInvalidDestinationOffset() throws IOException {
main.execute();
Assert.fail("Should have thrown");
} catch (PolyglotException e) {
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out of bounds memory access"));
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out-of-bounds") && e.getMessage().contains("memory access"));
}
});
}
Expand Down Expand Up @@ -741,7 +741,7 @@ public void testMemoryCopyInvalidLength() throws IOException {
main.execute();
Assert.fail("Should have thrown");
} catch (PolyglotException e) {
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out of bounds memory access"));
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out-of-bounds") && e.getMessage().contains("memory access"));
}
});
}
Expand All @@ -764,7 +764,7 @@ public void testMemoryCopyInvalidSourceOffset() throws IOException {
main.execute();
Assert.fail("Should have thrown");
} catch (PolyglotException e) {
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out of bounds memory access"));
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out-of-bounds") && e.getMessage().contains("memory access"));
}
});
}
Expand All @@ -787,7 +787,7 @@ public void testMemoryCopyInvalidDestinationOffset() throws IOException {
main.execute();
Assert.fail("Should have thrown");
} catch (PolyglotException e) {
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out of bounds memory access"));
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out-of-bounds") && e.getMessage().contains("memory access"));
}
});
}
Expand Down Expand Up @@ -825,7 +825,7 @@ public void testMemoryFillInvalidLength() throws IOException {
main.execute();
Assert.fail("Should have thrown");
} catch (PolyglotException e) {
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out of bounds memory access"));
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out-of-bounds") && e.getMessage().contains("memory access"));
}
});
}
Expand All @@ -846,7 +846,7 @@ public void testMemoryFillInvalidDestinationOffset() throws IOException {
main.execute();
Assert.fail("Should have thrown");
} catch (PolyglotException e) {
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out of bounds memory access"));
Assert.assertTrue("Expected out of bounds error", e.getMessage().contains("out-of-bounds") && e.getMessage().contains("memory access"));
}
});
}
Expand Down
4 changes: 1 addition & 3 deletions wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java
Original file line number Diff line number Diff line change
Expand Up @@ -765,10 +765,8 @@ void resolveDataSegment(WasmContext context, WasmInstance instance, int dataSegm
}

WasmMemoryLibrary memoryLib = WasmMemoryLibrary.getUncached();
Assert.assertUnsignedLongLessOrEqual(baseAddress, memoryLib.byteSize(memory), Failure.OUT_OF_BOUNDS_MEMORY_ACCESS);
Assert.assertUnsignedLongLessOrEqual(baseAddress + byteLength, memoryLib.byteSize(memory), Failure.OUT_OF_BOUNDS_MEMORY_ACCESS);
final byte[] bytecode = instance.module().bytecode();
memoryLib.initialize(memory, bytecode, bytecodeOffset, baseAddress, byteLength);
memoryLib.initialize(memory, null, bytecode, bytecodeOffset, baseAddress, byteLength);
instance.setDataInstance(dataSegmentId, droppedDataInstanceOffset);
};
final ArrayList<Sym> dependencies = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ public void reset() {
currentMinSize = declaredMinSize;
}

private void validateAddress(Node node, long address, long length) {
validateAddress(node, address, length, byteSize());
}

private WasmException trapOutOfBounds(Node node, long address, long length) {
return trapOutOfBounds(node, address, length, byteSize());
}

// Checkstyle: stop
@ExportMessage
public int load_i32(Node node, long address) {
Expand Down Expand Up @@ -349,13 +357,6 @@ public void store_i128(Node node, long address, Vector128 value) {
}
}

private static void validateAtomicAddress(Node node, long address, int length) {
if ((address & (length - 1)) != 0) {
CompilerDirectives.transferToInterpreterAndInvalidate();
throw trapUnalignedAtomic(node, address, length);
}
}

@ExportMessage
public int atomic_load_i32(Node node, long address) {
validateAtomicAddress(node, address, 4);
Expand Down Expand Up @@ -989,10 +990,8 @@ public long atomic_rmw_cmpxchg_i64(Node node, long address, long expected, long
@ExportMessage
@TruffleBoundary
public int atomic_notify(Node node, long address, int count) {
validateAddress(node, address, 4);
validateAtomicAddress(node, address, 4);
if (outOfBounds(address, 4)) {
throw trapOutOfBounds(node, address, 4);
}
if (!this.isShared()) {
return 0;
}
Expand All @@ -1002,10 +1001,8 @@ public int atomic_notify(Node node, long address, int count) {
@ExportMessage
@TruffleBoundary
public int atomic_wait32(Node node, long address, int expected, long timeout) {
validateAddress(node, address, 4);
validateAtomicAddress(node, address, 4);
if (outOfBounds(address, 4)) {
throw trapOutOfBounds(node, address, 4);
}
if (!this.isShared()) {
throw trapUnsharedMemory(node);
}
Expand All @@ -1015,10 +1012,8 @@ public int atomic_wait32(Node node, long address, int expected, long timeout) {
@ExportMessage
@TruffleBoundary
public int atomic_wait64(Node node, long address, long expected, long timeout) {
validateAddress(node, address, 4, 8);
validateAtomicAddress(node, address, 8);
if (outOfBounds(address, 8)) {
throw trapOutOfBounds(node, address, 8);
}
if (!this.isShared()) {
throw trapUnsharedMemory(node);
}
Expand All @@ -1027,23 +1022,31 @@ public int atomic_wait64(Node node, long address, long expected, long timeout) {
// Checkstyle: resume

@ExportMessage
public void initialize(byte[] source, int sourceOffset, long destinationOffset, int length) {
assert destinationOffset + length <= byteSize();
public void initialize(Node node, byte[] source, int sourceOffset, long destinationOffset, int length) {
validateLength(node, length);
validateAddress(node, destinationOffset, length);
if (sourceOffset < 0 || sourceOffset > source.length - length) {
CompilerDirectives.transferToInterpreterAndInvalidate();
throw trapOutOfBoundsBuffer(node, sourceOffset, length, source.length);
}
System.arraycopy(source, sourceOffset, buffer(), (int) destinationOffset, length);
}

@ExportMessage
@TruffleBoundary
public void fill(long offset, long length, byte value) {
assert offset + length <= byteSize();
public void fill(Node node, long offset, long length, byte value) {
validateLength(node, length);
validateAddress(node, offset, length);
Arrays.fill(buffer(), (int) offset, (int) (offset + length), value);
}

@ExportMessage
public void copyFrom(WasmMemory source, long sourceOffset, long destinationOffset, long length) {
public void copyFrom(Node node, WasmMemory source, long sourceOffset, long destinationOffset, long length) {
assert source instanceof ByteArrayWasmMemory;
assert destinationOffset < byteSize();
ByteArrayWasmMemory s = (ByteArrayWasmMemory) source;
validateLength(node, length);
s.validateAddress(node, sourceOffset, length);
validateAddress(node, destinationOffset, length);
System.arraycopy(s.buffer(), (int) sourceOffset, buffer(), (int) destinationOffset, (int) length);
}

Expand All @@ -1059,36 +1062,29 @@ public void close() {
dynamicBuffer = null;
}

private boolean outOfBounds(int offset, int length) {
return length < 0 || offset < 0 || offset > byteSize() - length;
}

private boolean outOfBounds(long offset, long length) {
return length < 0 || offset < 0 || offset > byteSize() - length;
}

@ExportMessage
@TruffleBoundary
public int copyFromStream(Node node, InputStream stream, int offset, int length) throws IOException {
if (outOfBounds(offset, length)) {
throw trapOutOfBounds(node, offset, length);
}
validateLength(node, length);
validateAddress(node, offset, length);
return stream.read(buffer(), offset, length);
}

@ExportMessage
@TruffleBoundary
public void copyToStream(Node node, OutputStream stream, int offset, int length) throws IOException {
if (outOfBounds(offset, length)) {
throw trapOutOfBounds(node, offset, length);
}
validateLength(node, length);
validateAddress(node, offset, length);
stream.write(buffer(), offset, length);
}

@ExportMessage
public void copyToBuffer(Node node, byte[] dst, long srcOffset, int dstOffset, int length) {
if (outOfBounds(srcOffset, length)) {
throw trapOutOfBounds(node, srcOffset, length);
validateLength(node, length);
validateAddress(node, srcOffset, length);
if (dstOffset < 0 || dstOffset > dst.length - length) {
CompilerDirectives.transferToInterpreterAndInvalidate();
throw trapOutOfBoundsBuffer(node, dstOffset, length, dst.length);
}
System.arraycopy(buffer(), (int) srcOffset, dst, dstOffset, length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,21 +173,8 @@ public void reset() {
currentMinSize = declaredMinSize;
}

private void validateAddress(Node node, long address, int length) {
assert length >= 1;
long byteSize = byteSize();
assert byteSize >= 0;
if (address < 0 || address > byteSize - length) {
CompilerDirectives.transferToInterpreterAndInvalidate();
throw trapOutOfBounds(node, address, length);
}
}

private static void validateAtomicAddress(Node node, long address, int length) {
if ((address & (length - 1)) != 0) {
CompilerDirectives.transferToInterpreterAndInvalidate();
throw trapUnalignedAtomic(node, address, length);
}
private void validateAddress(Node node, long address, long length) {
validateAddress(node, address, length, byteSize());
}

// Checkstyle: stop
Expand Down Expand Up @@ -966,19 +953,30 @@ public WasmMemory duplicate() {
}

@ExportMessage
public void initialize(byte[] source, int sourceOffset, long destinationOffset, int length) {
public void initialize(Node node, byte[] source, int sourceOffset, long destinationOffset, int length) {
validateLength(node, length);
validateAddress(node, destinationOffset, length);
if (sourceOffset < 0 || sourceOffset > source.length - length) {
CompilerDirectives.transferToInterpreterAndInvalidate();
throw trapOutOfBoundsBuffer(node, sourceOffset, length, source.length);
}
unsafe.copyMemory(source, Unsafe.ARRAY_BYTE_BASE_OFFSET + sourceOffset * Unsafe.ARRAY_BYTE_INDEX_SCALE, null, startAddress + destinationOffset, length);
}

@ExportMessage
public void fill(long offset, long length, byte value) {
public void fill(Node node, long offset, long length, byte value) {
validateLength(node, length);
validateAddress(node, offset, length);
unsafe.setMemory(startAddress + offset, length, value);
}

@ExportMessage
public void copyFrom(WasmMemory source, long sourceOffset, long destinationOffset, long length) {
public void copyFrom(Node node, WasmMemory source, long sourceOffset, long destinationOffset, long length) {
assert source instanceof NativeWasmMemory;
final NativeWasmMemory s = (NativeWasmMemory) source;
validateLength(node, length);
s.validateAddress(node, sourceOffset, length);
validateAddress(node, destinationOffset, length);
unsafe.copyMemory(s.startAddress + sourceOffset, this.startAddress + destinationOffset, length);
}

Expand All @@ -1002,20 +1000,11 @@ public void close() {
}
}

private boolean outOfBounds(int offset, int length) {
return length < 0 || offset < 0 || offset > byteSize() - length;
}

private boolean outOfBounds(long offset, long length) {
return length < 0 || offset < 0 || offset > byteSize() - length;
}

@ExportMessage
@TruffleBoundary
public int copyFromStream(Node node, InputStream stream, int offset, int length) throws IOException {
if (outOfBounds(offset, length)) {
throw trapOutOfBounds(node, offset, length);
}
validateLength(node, length);
validateAddress(node, offset, length);
int totalBytesRead = 0;
for (int i = 0; i < length; i++) {
int byteRead = stream.read();
Expand All @@ -1034,9 +1023,8 @@ public int copyFromStream(Node node, InputStream stream, int offset, int length)
@ExportMessage
@TruffleBoundary
public void copyToStream(Node node, OutputStream stream, int offset, int length) throws IOException {
if (outOfBounds(offset, length)) {
throw trapOutOfBounds(node, offset, length);
}
validateLength(node, length);
validateAddress(node, offset, length);
for (int i = 0; i < length; i++) {
byte b = unsafe.getByte(startAddress + offset + i);
stream.write(b & 0x0000_00ff);
Expand All @@ -1045,8 +1033,11 @@ public void copyToStream(Node node, OutputStream stream, int offset, int length)

@ExportMessage
public void copyToBuffer(Node node, byte[] dst, long srcOffset, int dstOffset, int length) {
if (outOfBounds(srcOffset, length)) {
throw trapOutOfBounds(node, srcOffset, length);
validateLength(node, length);
validateAddress(node, srcOffset, length);
if (dstOffset < 0 || dstOffset > dst.length - length) {
CompilerDirectives.transferToInterpreterAndInvalidate();
throw trapOutOfBoundsBuffer(node, dstOffset, length, dst.length);
}
unsafe.copyMemory(null, startAddress + srcOffset, dst, Unsafe.ARRAY_BYTE_BASE_OFFSET + (long) dstOffset * Unsafe.ARRAY_BYTE_INDEX_SCALE, length);
}
Expand Down
Loading

0 comments on commit f2067db

Please sign in to comment.