Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 57 additions & 25 deletions poly_matrix/poly_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,45 +568,77 @@ def get_matrix_sparse(self, variables=None, output_type="coo", verbose=False):
indices_i = generate_indices(variable_dict["i"])
indices_j = generate_indices(variable_dict["j"])

i_list = []
j_list = []
data_list = []
variable_dict_i = variable_dict["i"]
variable_dict_j = variable_dict["j"]
requested_i = set(variable_dict_i)
requested_j = set(variable_dict_j)

i_chunks = []
j_chunks = []
data_chunks = []

# Loop through stored blocks and keep only requested ones.
for key_i, row_blocks in self.matrix.items():
if key_i not in requested_i:
continue

# Loop through blocks of stored matrices
for key_i in variable_dict["i"]:
for key_j in variable_dict["j"]:
try:
values = self.matrix[key_i][key_j]
except KeyError:
row_offset = indices_i[key_i]
size_i = variable_dict_i[key_i]

for key_j, values in row_blocks.items():
if key_j not in requested_j:
continue
Comment on lines +580 to 590

size_j = variable_dict_j[key_j]
# Check if blocks appear in variable dictionary
assert values.shape == (
variable_dict["i"][key_i],
variable_dict["j"][key_j],
), f"Variable size does not match input matrix size, variables: {(variable_dict['i'][key_i], variable_dict['j'][key_j])}, matrix: {values.shape}"
# generate list of indices for sparse mat input
size_i,
size_j,
), f"Variable size does not match input matrix size, variables: {(size_i, size_j)}, matrix: {values.shape}"

col_offset = indices_j[key_j]

# generate index/value arrays for sparse mat input
if sp.issparse(values):
rows, cols = values.nonzero()
data_list += list(values.data)
block = values.tocoo()
if block.nnz == 0:
continue
rows = block.row + row_offset
cols = block.col + col_offset
data = block.data
Comment on lines 602 to +608
else:
rows, cols = np.nonzero(values)
data_list += list(values[rows, cols])
i_list += list(rows + indices_i[key_i])
j_list += list(cols + indices_j[key_j])
local_rows, local_cols = np.nonzero(values)
if local_rows.size == 0:
continue
rows = local_rows + row_offset
cols = local_cols + col_offset
data = values[local_rows, local_cols]

i_chunks.append(rows)
j_chunks.append(cols)
data_chunks.append(data)

shape = get_shape(variable_dict_i, variable_dict_j)
if data_chunks:
i_data = np.concatenate(i_chunks)
j_data = np.concatenate(j_chunks)
values_data = np.concatenate(data_chunks)
else:
i_data = np.empty(0, dtype=int)
j_data = np.empty(0, dtype=int)
values_data = np.empty(0, dtype=float)

shape = get_shape(variable_dict["i"], variable_dict["j"])
mat = sp.coo_matrix((values_data, (i_data, j_data)), shape=shape)

if output_type == "coo":
mat = sp.coo_matrix((data_list, (i_list, j_list)), shape=shape)
return mat
elif output_type == "csr":
mat = sp.csr_matrix((data_list, (i_list, j_list)), shape=shape)
return mat.tocsr()
elif output_type == "csc":
mat = sp.csc_matrix((data_list, (i_list, j_list)), shape=shape)
return mat.tocsc()
else:
raise ValueError(f"Unknown matrix type {output_type}")

return mat

def get_vector(self, variables=None, **kwargs):
"""

Expand Down
27 changes: 27 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "poly_matrix"
version = "0.3.1"
description = "Code for efficient handling of sparse matrices"
readme = "README.md"
requires-python = ">=3.8"
license = { file = "LICENSE" }
authors = [
{ name = "Frederike Dümbgen", email = "frederike.dumbgen@utoronto.ca" },
{ name = "Connor Holmes", email = "connor.holmes@mail.utoronto.ca" }
]
dependencies = [
"pandas>=1.5",
"numpy>=1.23",
"scipy>=1.9",
"matplotlib>=3.6"
]

[project.urls]
Homepage = "https://github.com/utiasASRL/poly_matrix.git"

Comment on lines +5 to +25
[tool.setuptools.packages.find]
exclude = ["_test*"]
Loading