Skip to content

Commit fe0f93e

Browse files
authored
Merge pull request #111 from EXP-code/fixKmeans
k-means update for pyEXP
2 parents dccf0b8 + 4f40e7c commit fe0f93e

4 files changed

Lines changed: 231 additions & 74 deletions

File tree

expui/KMeans.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,28 @@ namespace MSSA
1313
//
1414
cen.clear();
1515

16-
if (s>0) { // Seed centers by stride
16+
if (s>0) {
17+
// Seed centers by stride
1718
for (int i=0; i<classes.size(); i+=s) {
1819
if (cen.size()>=k) break;
1920
cen.push_back(classes.at(i)->x);
2021
}
2122
k = cen.size();
22-
} else { // obtain a seed from the system clock
23+
} else {
24+
// Obtain a seed from the system clock
2325
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
26+
27+
// Make a random generator
2428
std::mt19937 gen(seed);
25-
// Seed centers randomly
29+
30+
// Randomly shuffle a list of indexes
31+
std::vector<int> indx;
32+
for (int i=0; i<classes.size(); i++) indx.push_back(i);
33+
std::shuffle(indx.begin(), indx.end(), gen);
34+
35+
// Seed centers randomly from the initial point list
2636
for (int i=0; i<k; ++i) {
27-
cen.push_back(classes.at(gen() % classes.size())->x);
37+
cen.push_back(classes.at(indx[i])->x);
2838
}
2939
}
3040

expui/expMSSA.H

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,39 @@ namespace MSSA
243243
//! Create wcorrlation matricies and output PNG
244244
void wcorrPNG();
245245

246+
//@{
246247
/**
247248
Kmeans analysis of the trajectories by PC with fixed cluster size
249+
*/
250+
251+
/** Perform Kmeans analysis for a given number of clusters and
252+
print the results
248253
254+
@param clusters is the maximum number of clusters considered
255+
@param stride is the seed strategy. If positive, it is used to
256+
select initial cluster centers by stride from the PC list. If
257+
it is zero, centers are selected randomly from the PC list
249258
@param toTerm write to stdout if true
250259
@param toFile write to file if true
251260
*/
252-
void kmeans(int clusters, bool toTerm=true, bool toFile=false);
261+
void kmeansPrint(int clusters, int stride,
262+
bool toTerm=true, bool toFile=false);
263+
264+
/** Get Kmeans analysis per channel
265+
266+
@param clusters is the number of clusters to seed
267+
@param key is the channel id vector<int>
268+
*/
269+
std::tuple<std::vector<int>, std::vector<double>, double>
270+
kmeansChannel(Key key, int clusters, int stride);
271+
272+
/** Get Kmeans analysis for all channels
273+
274+
@param clusters is the number of clusters to seed
275+
@param stride is the seed strategy for initial centers
276+
*/
277+
std::tuple<std::vector<int>, std::vector<double>, double>
278+
kmeans(int clusters, int stride);
253279

254280
//! Save current MSSA state to an HDF5 file with the given prefix
255281
void saveState(const std::string& prefix);

expui/expMSSA.cc

Lines changed: 142 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,10 +1063,10 @@ namespace MSSA {
10631063
#endif
10641064
}
10651065

1066-
void expMSSA::kmeans(int clusters, bool toTerm, bool toFile)
1066+
void expMSSA::kmeansPrint(int clusters, int stride, bool toTerm, bool toFile)
10671067
{
10681068
if (clusters==0) {
1069-
std::cout << "expMSSA::kmeans: you need clusters>0" << std::endl;
1069+
std::cout << "expMSSA::kmeansPrint: you need clusters>0" << std::endl;
10701070
return;
10711071
}
10721072

@@ -1087,41 +1087,22 @@ namespace MSSA {
10871087
std::cerr << "Error opening file <" << filename << ">" << std::endl;
10881088
}
10891089

1090-
// W-correlation-based distance functor
1091-
//
1092-
KMeans::WcorrDistance dist(numT, numW);
1093-
10941090
for (auto u : mean) {
1095-
// Pack point array
1096-
//
1097-
std::vector<KMeans::Ptr> data;
1098-
for (int j=0; j<ncomp; j++) {
1099-
data.push_back(std::make_shared<KMeans::Point>(numT));
1100-
for (int i=0; i<numT; i++) data.back()->x[i] = RC[u.first](i, j);
1101-
}
11021091

1103-
// Initialize k-means routine
1104-
//
1105-
KMeans::kMeansClustering kMeans(data);
1106-
1107-
// Run 100 iterations
1108-
//
1109-
kMeans.iterate(dist, 100, clusters, 2, false);
1110-
1111-
// Retrieve cluster associations
1112-
//
1113-
auto results = kMeans.get_results();
1092+
auto [id, dd, tol] = kmeansChannel(u.first, clusters, stride);
11141093

11151094
// Write to file
11161095
//
11171096
if (out) {
11181097
out << std::string(60, '-') << std::endl
11191098
<< " *** n=" << u.first << std::endl
1099+
<< " *** tol=" << tol << std::endl
11201100
<< std::string(60, '-') << std::endl;
11211101

1122-
for (int j=0; j<results.size(); j++) {
1102+
for (int j=0; j<id.size(); j++) {
11231103
out << std::setw(6) << j
1124-
<< std::setw(12) << std::get<1>(results[j])
1104+
<< std::setw(12) << id[j]
1105+
<< std::setw(16) << dd[j]
11251106
<< std::endl;
11261107
}
11271108
}
@@ -1133,51 +1114,31 @@ namespace MSSA {
11331114
<< " *** n=" << u.first << std::endl
11341115
<< std::string(60, '-') << std::endl;
11351116

1136-
for (int j=0; j<results.size(); j++) {
1137-
std::cout << std::setw(6) << j
1138-
<< std::setw(12) << std::get<1>(results[j])
1117+
for (int j=0; j<id.size(); j++) {
1118+
std::cout << std::setw( 6) << j
1119+
<< std::setw( 9) << id[j]
1120+
<< std::setw(16) << dd[j]
11391121
<< std::endl;
11401122
}
11411123
}
11421124
}
11431125

11441126
if (params["allchan"]) {
11451127

1146-
// Pack point array
1147-
//
1148-
std::vector<KMeans::Ptr> data;
1149-
int sz = mean.size();
1150-
for (int j=0; j<ncomp; j++) {
1151-
data.push_back(std::make_shared<KMeans::Point>(numT*sz));
1152-
int c = 0;
1153-
for (auto u : mean) {
1154-
for (int i=0; i<numT; i++) data.back()->x[c++] = RC[u.first](i, j);
1155-
}
1156-
}
1157-
1158-
// Initialize k-means routine
1159-
//
1160-
KMeans::kMeansClustering kMeans(data);
1161-
1162-
// Run 100 iterations
1163-
//
1164-
KMeans::WcorrDistMulti dist2(numT, numW, sz);
1165-
kMeans.iterate(dist2, 100, clusters, 2, false);
1166-
1167-
// Retrieve cluster associations
1168-
//
1169-
auto results = kMeans.get_results();
1128+
auto [id, dd, tol] = kmeans(clusters, stride);
11701129

11711130
// Write to file
11721131
//
11731132
if (out) {
11741133
out << std::string(60, '-') << std::endl
11751134
<< " *** total" << std::endl
1135+
<< " *** tol=" << tol << std::endl
11761136
<< std::string(60, '-') << std::endl;
11771137

1178-
for (int j=0; j<results.size(); j++) {
1179-
out << std::setw(6) << j
1180-
<< std::setw(9) << std::get<1>(results[j])
1138+
for (int j=0; j<id.size(); j++) {
1139+
out << std::setw( 6) << j
1140+
<< std::setw( 9) << id[j]
1141+
<< std::setw(16) << dd[j]
11811142
<< std::endl;
11821143
}
11831144
}
@@ -1189,9 +1150,10 @@ namespace MSSA {
11891150
<< " *** total" << std::endl
11901151
<< std::string(60, '-') << std::endl;
11911152

1192-
for (int j=0; j<results.size(); j++) {
1193-
std::cout << std::setw(6) << j
1194-
<< std::setw(9) << std::get<1>(results[j])
1153+
for (int j=0; j<id.size(); j++) {
1154+
std::cout << std::setw( 6) << j
1155+
<< std::setw( 9) << id[j]
1156+
<< std::setw(16) << dd[j]
11951157
<< std::endl;
11961158
}
11971159
}
@@ -1203,7 +1165,128 @@ namespace MSSA {
12031165
std::cout << "Bad output stream for <" << filename << ">" << std::endl;
12041166
}
12051167
out.close();
1168+
}
1169+
1170+
1171+
std::tuple<std::vector<int>, std::vector<double>, double>
1172+
expMSSA::kmeansChannel(Key key, int clusters, int stride)
1173+
{
1174+
if (clusters==0) {
1175+
throw std::invalid_argument("expMSSA::kmeansChannel: clusters==0");
1176+
}
1177+
1178+
if (stride<0) {
1179+
throw std::invalid_argument("expMSSA::kmeansChannel: stride must be >= 0");
1180+
}
1181+
1182+
if (mean.find(key) == mean.end()) {
1183+
std::ostringstream sout;
1184+
sout << "expMSSA::kmeansKey: key <" << key << "> not found";
1185+
throw std::invalid_argument(sout.str());
1186+
}
1187+
1188+
KMeans::WcorrDistance dist(numT, numW);
1189+
1190+
// Pack point array
1191+
//
1192+
std::vector<KMeans::Ptr> data;
1193+
for (int j=0; j<ncomp; j++) {
1194+
data.push_back(std::make_shared<KMeans::Point>(numT));
1195+
for (int i=0; i<numT; i++) data.back()->x[i] = RC[key](i, j);
1196+
}
1197+
1198+
// Initialize k-means routine
1199+
//
1200+
KMeans::kMeansClustering kMeans(data);
1201+
1202+
// Run 100 iterations
1203+
//
1204+
kMeans.iterate(dist, 1000, clusters, stride);
1205+
1206+
// Retrieve cluster associations
1207+
//
1208+
auto results = kMeans.get_results();
1209+
auto centers = kMeans.get_cen();
1210+
1211+
// Compute inertia
1212+
//
1213+
auto inertia = [&](int j, int id) -> double {
1214+
auto & cen = centers[id];
1215+
double d = 0.0;
1216+
for (int i=0; i<cen.size(); i++)
1217+
d += (cen[i] - data[j]->x[i])*(cen[i] - data[j]->x[i]);
1218+
return sqrt(d);
1219+
};
1220+
1221+
// Pack return vector
1222+
//
1223+
std::vector<int> retI;
1224+
std::vector<double> retD;
1225+
for (int j=0; j<results.size(); j++) {
1226+
retI.push_back(std::get<1>(results[j]));
1227+
retD.push_back(inertia(j, std::get<1>(results[j])));
1228+
}
1229+
1230+
return {retI, retD, kMeans.getTol()};
1231+
}
1232+
1233+
std::tuple<std::vector<int>, std::vector<double>, double>
1234+
expMSSA::kmeans(int clusters, int stride)
1235+
{
1236+
if (clusters==0) {
1237+
throw std::invalid_argument("expMSSA::kmeans: you need clusters>0");
1238+
}
1239+
1240+
if (stride<0) {
1241+
throw std::invalid_argument("expMSSA::kmeans: stride must be >= 0");
1242+
}
1243+
1244+
// Pack point array
1245+
//
1246+
std::vector<KMeans::Ptr> data;
1247+
int sz = mean.size();
1248+
for (int j=0; j<ncomp; j++) {
1249+
data.push_back(std::make_shared<KMeans::Point>(numT*sz));
1250+
int c = 0;
1251+
for (auto u : mean) {
1252+
for (int i=0; i<numT; i++) data.back()->x[c++] = RC[u.first](i, j);
1253+
}
1254+
}
1255+
1256+
// Initialize k-means routine
1257+
//
1258+
KMeans::kMeansClustering kMeans(data);
1259+
1260+
// Run 100 iterations
1261+
//
1262+
KMeans::WcorrDistMulti dist(numT, numW, sz);
1263+
kMeans.iterate(dist, 1000, clusters, stride);
1264+
1265+
// Retrieve cluster associations
1266+
//
1267+
auto results = kMeans.get_results();
1268+
auto centers = kMeans.get_cen();
1269+
1270+
// Compute inertia
1271+
//
1272+
auto inertia = [&](int j, int id) -> double {
1273+
auto & cen = centers[id];
1274+
double d = 0.0;
1275+
for (int i=0; i<cen.size(); i++)
1276+
d += (cen[i] - data[j]->x[i])*(cen[i] - data[j]->x[i]);
1277+
return sqrt(d);
1278+
};
1279+
1280+
// Pack return vector
1281+
//
1282+
std::vector<int> retI;
1283+
std::vector<double> retD;
1284+
for (int j=0; j<results.size(); j++) {
1285+
retI.push_back(std::get<1>(results[j]));
1286+
retD.push_back(inertia(j, std::get<1>(results[j])));
1287+
}
12061288

1289+
return {retI, retD, kMeans.getTol()};
12071290
}
12081291

12091292
std::map<std::string, CoefClasses::CoefsPtr> expMSSA::getReconstructed(bool reconstructmean)

0 commit comments

Comments
 (0)