-
-
Notifications
You must be signed in to change notification settings - Fork 50.4k
Refactor memoized knapsack implementation to remove global state #14535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
e33f437
1352249
a3fc426
db897b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,27 +6,55 @@ | |
| using dynamic programming. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| def mf_knapsack(i, wt, val, j): | ||
| from functools import lru_cache | ||
|
|
||
|
|
||
| def mf_knapsack(i: int, wt: list[int], val: list[int], j: int) -> int: | ||
| """ | ||
| This code involves the concept of memory functions. Here we solve the subproblems | ||
| which are needed unlike the below example | ||
| F is a 2D array with ``-1`` s filled up | ||
| Return the optimal value for the 0/1 knapsack problem using memoization. | ||
|
|
||
| This implementation caches subproblems with ``functools.lru_cache`` and avoids | ||
| global mutable state. | ||
|
|
||
| >>> mf_knapsack(4, [4, 3, 2, 3], [3, 2, 4, 4], 6) | ||
| 8 | ||
| >>> mf_knapsack(3, [10, 20, 30], [60, 100, 120], 50) | ||
| 220 | ||
| >>> mf_knapsack(0, [1], [10], 50) | ||
| 0 | ||
| """ | ||
| global f # a global dp table for knapsack | ||
| if f[i][j] < 0: | ||
| if j < wt[i - 1]: | ||
| val = mf_knapsack(i - 1, wt, val, j) | ||
| else: | ||
| val = max( | ||
| mf_knapsack(i - 1, wt, val, j), | ||
| mf_knapsack(i - 1, wt, val, j - wt[i - 1]) + val[i - 1], | ||
| ) | ||
| f[i][j] = val | ||
| return f[i][j] | ||
| if i < 0: | ||
| raise ValueError("The number of items to consider cannot be negative.") | ||
| if j < 0: | ||
| raise ValueError("The knapsack capacity cannot be negative.") | ||
| if len(wt) != len(val): | ||
| raise ValueError("The number of weights must match the number of values.") | ||
| if i > len(wt): | ||
| raise ValueError("The number of items to consider cannot exceed input length.") | ||
|
|
||
| weights = tuple(wt) | ||
| values = tuple(val) | ||
|
|
||
| @lru_cache(maxsize=None) | ||
| def solve(item_count: int, capacity: int) -> int: | ||
| if item_count == 0 or capacity == 0: | ||
| return 0 | ||
| if weights[item_count - 1] > capacity: | ||
| return solve(item_count - 1, capacity) | ||
| return max( | ||
| solve(item_count - 1, capacity), | ||
| solve(item_count - 1, capacity - weights[item_count - 1]) | ||
| + values[item_count - 1], | ||
| ) | ||
|
|
||
| return solve(i, j) | ||
|
|
||
|
|
||
| def knapsack(w, wt, val, n): | ||
| def knapsack( | ||
| w: int, wt: list[int], val: list[int], n: int | ||
| ) -> tuple[int, list[list[int]]]: | ||
| dp = [[0] * (w + 1) for _ in range(n + 1)] | ||
|
|
||
| for i in range(1, n + 1): | ||
|
|
@@ -36,10 +64,10 @@ | |
| else: | ||
| dp[i][w_] = dp[i - 1][w_] | ||
|
|
||
| return dp[n][w_], dp | ||
| return dp[n][w], dp | ||
|
|
||
|
||
|
|
||
| def knapsack_with_example_solution(w: int, wt: list, val: list): | ||
| def knapsack_with_example_solution(w: int, wt: list, val: list) -> tuple[int, set[int]]: | ||
| """ | ||
| Solves the integer weights knapsack problem returns one of | ||
| the several possible optimal subsets. | ||
|
|
@@ -100,7 +128,9 @@ | |
| return optimal_val, example_optional_set | ||
|
|
||
|
|
||
| def _construct_solution(dp: list, wt: list, i: int, j: int, optimal_set: set): | ||
| def _construct_solution( | ||
| dp: list[list[int]], wt: list[int], i: int, j: int, optimal_set: set[int] | ||
| ) -> None: | ||
| """ | ||
| Recursively reconstructs one of the optimal subsets given | ||
| a filled DP table and the vector of weights | ||
|
|
@@ -139,10 +169,9 @@ | |
| wt = [4, 3, 2, 3] | ||
| n = 4 | ||
| w = 6 | ||
| f = [[0] * (w + 1)] + [[0] + [-1] * (w + 1) for _ in range(n + 1)] | ||
| optimal_solution, _ = knapsack(w, wt, val, n) | ||
| print(optimal_solution) | ||
| print(mf_knapsack(n, wt, val, w)) # switched the n and w | ||
| print(mf_knapsack(n, wt, val, w)) | ||
|
|
||
| # testing the dynamic programming problem with example | ||
| # the optimal subset for the above example are items 3 and 4 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
mf_knapsack()annotations are more restrictive than the implementation: it only relies onlen()+ indexing and immediately convertswt/valto tuples, so tuples (and otherSequencetypes) work too. Consider widening the types tocollections.abc.Sequence[int]to match actual supported inputs (this pattern is already used elsewhere, e.g.dynamic_programming/max_subarray_sum.py:16).