diff --git a/sknw/sknw.py b/sknw/sknw.py index bb5e838..99bc1f7 100644 --- a/sknw/sknw.py +++ b/sknw/sknw.py @@ -25,7 +25,7 @@ def mark(img, nbs): # mark the array use (0, 1, 2) @jit(nopython=True) # trans index to r, c... def idx2rc(idx, acc): - rst = np.zeros((len(idx), len(acc)), dtype=np.int16) + rst = np.zeros((len(idx), len(acc)), dtype=np.int64) for i in range(len(idx)): for j in range(len(acc)): rst[i,j] = idx[i]//acc[j] @@ -57,6 +57,8 @@ def trace(img, p, nbs, acc, buf): c1 = 0; c2 = 0; newp = 0 cur = 1 + buf_len = len(buf) + buf_overflow = False while True: buf[cur] = p img[p] = 0 @@ -74,12 +76,15 @@ def trace(img, p, nbs, acc, buf): newp = cp p = newp if c2!=0:break - return (c1-10, c2-10, idx2rc(buf[:cur+1], acc)) + if cur >= buf_len: + buf_overflow = True + break + return (c1-10, c2-10, idx2rc(buf[:cur+1], acc)), buf_overflow @jit(nopython=True) # parse the image then get the nodes and edges -def parse_struc(img, nbs, acc, iso, ring): +def parse_struc(img, nbs, acc, iso, ring, buf_size): img = img.ravel() - buf = np.zeros(131072, dtype=np.int64) + buf = np.zeros(buf_size, dtype=np.int64) num = 10 nodes = [] for p in range(len(img)): @@ -93,23 +98,27 @@ def parse_struc(img, nbs, acc, iso, ring): if img[p] <10: continue for dp in nbs: if img[p+dp]==1: - edge = trace(img, p+dp, nbs, acc, buf) + edge, buf_overflow = trace(img, p+dp, nbs, acc, buf) + if buf_overflow: + return nodes, edges, buf_overflow edges.append(edge) - if not ring: return nodes, edges + if not ring: return nodes, edges, False for p in range(len(img)): if img[p]!=1: continue img[p] = num; num += 1 nodes.append(idx2rc([p], acc)) for dp in nbs: if img[p+dp]==1: - edge = trace(img, p+dp, nbs, acc, buf) + edge, buf_overflow = trace(img, p+dp, nbs, acc, buf) + if buf_overflow: + return nodes, edges, buf_overflow edges.append(edge) - return nodes, edges + return nodes, edges, False # use nodes and edges build a networkx graph def build_graph(nodes, edges, multi=False, full=True): os = np.array([i.mean(axis=0) for i in nodes]) - if full: os = os.round().astype(np.uint16) + if full: os = os.round().astype(np.uint32) graph = nx.MultiGraph() if multi else nx.Graph() for i in range(len(nodes)): graph.add_node(i, pts=nodes[i], o=os[i]) @@ -126,12 +135,14 @@ def mark_node(ske): mark(buf, nbs) return buf -def build_sknw(ske, multi=False, iso=True, ring=True, full=True): - buf = np.pad(ske, (1,1), mode='constant').astype(np.uint16) +def build_sknw(ske, multi=False, iso=True, ring=True, full=True, buf_size=131072): + buf = np.pad(ske, (1,1), mode='constant').astype(np.int64) nbs = neighbors(buf.shape) acc = np.cumprod((1,)+buf.shape[::-1][:-1])[::-1] mark(buf, nbs) - nodes, edges = parse_struc(buf, nbs, acc, iso, ring) + nodes, edges, buf_overflow = parse_struc(buf, nbs, acc, iso, ring, buf_size) + if buf_overflow: + raise Exception('Buffer overflow') return build_graph(nodes, edges, multi, full) # draw the graph