Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions mglearn/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,44 @@ def print_topics(topics, feature_names, sorting, topics_per_chunk=6,
these_topics = topics[i: i + topics_per_chunk]
# maybe we have less than topics_per_chunk left
len_this_chunk = len(these_topics)
# print topic headers
print(("topic {:<8}" * len_this_chunk).format(*these_topics))
print(("-------- {0:<5}" * len_this_chunk).format(""))
#generate list of sorted features and their lengths
row = []
for i in range(n_words):
row.append(feature_names[sorting[these_topics, i]])
topic_words = np.array(row).T
#get max feature length for each topic
max_feat_len = []
for t in topic_words:
max_feat_len.append(len(max(t, key = len)))
#generate space between strings equal to 1+len(longest string in topic)
result = [None]*len(these_topics)*2
result[::2] = these_topics
nums = np.array([(x - 5) for x in max_feat_len])
nums[nums < 0] = 0 #prevents spaces of negative length
result[1::2] = [str(x) for x in nums]
print(("topic {:<{}} " * len_this_chunk).format(*result))

#generate space between strings equal to 1+len(longest string in topic)
result = [None]*len(these_topics)*2
result[::2] = ['']*len(these_topics)
nums = np.array([(x - 8) for x in max_feat_len])
nums[nums < 0] = 0 #prevents spaces of negative length
result[1::2] = [str(x) for x in nums]
print(("-------- {:<{}} " * len_this_chunk).format(*result))

# print top n_words frequent words
for i in range(n_words):
#generate space between strings
result = [None]*len(these_topics)*2
result[::2] = feature_names[sorting[these_topics, i]]
result[1::2] = [str(x+2) for x in max_feat_len]
try:
print(("{:<14}" * len_this_chunk).format(
*feature_names[sorting[these_topics, i]]))
print(("{:<{}}" * len_this_chunk).format(*result))
except:
pass
print("\n")


def get_tree(tree, **kwargs):
try:
# python3
Expand Down