Skip to content

Commit 2cb1de0

Browse files
update readme
1 parent bd5321b commit 2cb1de0

File tree

2 files changed

+109
-1
lines changed

2 files changed

+109
-1
lines changed

README.md

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,32 @@
11
# torchmsat
22

3-
`pip install torchmsat`
3+
4+
## Installation
5+
6+
`pip install torchmsat`
7+
8+
## Usage
9+
10+
Clauses are in the [WDIMACS Input format](http://www.maxhs.org/docs/wdimacs.html).
11+
12+
```Python
13+
from torchmsat import solver
14+
15+
nv = 2
16+
clauses = [[1, 2],
17+
[1, -2],
18+
[-1, 2],
19+
[-1, -2]]
20+
21+
s = solver.Solver(prob.nv, prob.clauses)
22+
cost, sol = s.compute()
23+
```
24+
25+
Output:
26+
27+
```Python
28+
(1, tensor([[-1., 1., 1., -1., 1., -1.]]))
29+
```
30+
31+
where `1` is the minimum number of unsatisfied clauses, and the tensor represents literal assignments `-1` for `0` and `1` for `1`.
32+

torchmsat/solver.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import signal
2+
import torch
3+
4+
5+
class _Model_(torch.nn.Module):
6+
def __init__(self, nv, clauses) -> None:
7+
super(_Model_, self).__init__()
8+
9+
self.e = torch.ones((1, nv))
10+
x = torch.rand((1, nv))
11+
self.x = torch.nn.Parameter(x)
12+
13+
self.W = torch.zeros((len(clauses), nv))
14+
15+
self.target = torch.zeros((1, len(clauses)))
16+
17+
self.SAT = torch.zeros((1, len(clauses)))
18+
for i, clause in enumerate(clauses):
19+
for literal in clause:
20+
value = 1.0 if literal > 0 else -1.0
21+
literal_idx = abs(literal) - 1
22+
self.W[i, literal_idx] = value
23+
self.SAT[0, i] = -len(clause)
24+
25+
# Auxiliary for reporting a solution
26+
self.sol = torch.zeros_like(self.x)
27+
28+
def forward(self):
29+
act = torch.tanh(self.e*self.x) @ self.W.T
30+
self.sol[self.x > 0] = 1.0
31+
self.sol[self.x <= 0] = -1.0
32+
return act
33+
34+
def sat(self):
35+
unsat_clauses = (self.sol @ self.W.T) == self.SAT
36+
cost = torch.sum(unsat_clauses).item()
37+
return cost
38+
39+
40+
def __str__(self) -> str:
41+
return f'W={self.W}'
42+
43+
44+
class Solver():
45+
def __init__(self, nv, clauses) -> None:
46+
signal.signal(signal.SIGINT, self.signal_handler)
47+
48+
self.trace = {
49+
'start_time': 0.0,
50+
'nn_build_time': 0.0,
51+
'max_sat_time': 0.0,
52+
'nv': nv,
53+
'nc': len(clauses)
54+
}
55+
self.sols = []
56+
57+
self.model = _Model_(nv, clauses)
58+
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-2)
59+
self.loss = torch.nn.MSELoss()
60+
61+
def compute(self):
62+
for i in range(1000):
63+
self.optimizer.zero_grad()
64+
out = self.model()
65+
output = self.loss(out, self.model.target)
66+
output.backward()
67+
self.optimizer.step()
68+
69+
self.sols.append((self.model.sat(), self.model.sol))
70+
71+
return self.max_sat()
72+
73+
def max_sat(self):
74+
max_sat = min(self.sols, key=lambda sol: sol[0])
75+
return max_sat # returns (cost, assignment)
76+
77+
def signal_handler(self, sig, frame):
78+
print(self.max_sat())
79+

0 commit comments

Comments
 (0)