Skip to content

Commit

Permalink
fix: replace np.product with np.prod to enable use with newer numpy v…
Browse files Browse the repository at this point in the history
…ersions
  • Loading branch information
osbm committed Jan 19, 2025
1 parent b661598 commit 05af9fa
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions banding_removal/fastmri/data/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


def create_input(shape):
input = np.arange(np.product(shape)).reshape(shape)
input = np.arange(np.prod(shape)).reshape(shape)
input = torch.from_numpy(input).float()
return input

Expand Down Expand Up @@ -179,7 +179,7 @@ def test_normalize_instance(shape):
[3, 4, 5],
])
def test_roll(shift, dim, shape):
input = np.arange(np.product(shape)).reshape(shape)
input = np.arange(np.prod(shape)).reshape(shape)
out_torch = transforms.roll(torch.from_numpy(input), shift, dim).numpy()
out_numpy = np.roll(input, shift, dim)
assert np.allclose(out_torch, out_numpy)
Expand All @@ -190,7 +190,7 @@ def test_roll(shift, dim, shape):
[2, 4, 6],
])
def test_fftshift(shape):
input = np.arange(np.product(shape)).reshape(shape)
input = np.arange(np.prod(shape)).reshape(shape)
out_torch = transforms.fftshift(torch.from_numpy(input)).numpy()
out_numpy = np.fft.fftshift(input)
assert np.allclose(out_torch, out_numpy)
Expand All @@ -202,7 +202,7 @@ def test_fftshift(shape):
[2, 7, 5],
])
def test_ifftshift(shape):
input = np.arange(np.product(shape)).reshape(shape)
input = np.arange(np.prod(shape)).reshape(shape)
out_torch = transforms.ifftshift(torch.from_numpy(input)).numpy()
out_numpy = np.fft.ifftshift(input)
assert np.allclose(out_torch, out_numpy)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def create_input(shape):
x = np.arange(np.product(shape)).reshape(shape)
x = np.arange(np.prod(shape)).reshape(shape)
x = torch.from_numpy(x).float()

return x
Expand Down
6 changes: 3 additions & 3 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_root_sum_of_squares(shape, dim):
],
)
def test_roll(shift, dim, shape):
x = np.arange(np.product(shape)).reshape(shape)
x = np.arange(np.prod(shape)).reshape(shape)
if isinstance(shift, int) and isinstance(dim, int):
torch_shift = [shift]
torch_dim = [dim]
Expand All @@ -132,7 +132,7 @@ def test_roll(shift, dim, shape):
],
)
def test_fftshift(shape):
x = np.arange(np.product(shape)).reshape(shape)
x = np.arange(np.prod(shape)).reshape(shape)
out_torch = fastmri.fftshift(torch.from_numpy(x)).numpy()
out_numpy = np.fft.fftshift(x)

Expand All @@ -148,7 +148,7 @@ def test_fftshift(shape):
],
)
def test_ifftshift(shape):
x = np.arange(np.product(shape)).reshape(shape)
x = np.arange(np.prod(shape)).reshape(shape)
out_torch = fastmri.ifftshift(torch.from_numpy(x)).numpy()
out_numpy = np.fft.ifftshift(x)

Expand Down

0 comments on commit 05af9fa

Please sign in to comment.