Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file modified gradlew
100644 → 100755
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ public static void extractSoundFiles() {
new File(WerewolfApplication.class.getProtectionDomain().getCodeSource().getLocation().toURI()));
File soundFolder = new File("sounds");
if (!soundFolder.exists()) {
soundFolder.mkdir();
if (!soundFolder.mkdir()) {
log.error("Failed to create sounds directory");
return;
}
}
// Logic to clean and extract (simplified from original to avoid full deletion
// risk if not intent)
Expand Down
17 changes: 17 additions & 0 deletions src/main/java/dev/robothanzo/werewolf/config/SessionConfig.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
package dev.robothanzo.werewolf.config;

import org.mongodb.spring.session.config.annotation.web.http.EnableMongoHttpSession;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.session.web.http.CookieSerializer;
import org.springframework.session.web.http.DefaultCookieSerializer;

@Configuration
@EnableMongoHttpSession(collectionName = "http_sessions")
public class SessionConfig {

@Bean
public CookieSerializer cookieSerializer() {
DefaultCookieSerializer serializer = new DefaultCookieSerializer();
serializer.setSameSite("Lax");
serializer.setUseHttpOnlyCookie(true);

// Use secure cookies for HTTPS environments (production)
String dashboardUrl = System.getenv().getOrDefault("DASHBOARD_URL", "http://localhost:5173");
boolean isSecureEnvironment = dashboardUrl.startsWith("https://");
serializer.setUseSecureCookie(isSecureEnvironment);

return serializer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ public WebSocketConfig(GlobalWebSocketHandler globalWebSocketHandler) {
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(globalWebSocketHandler, "/ws")
.addInterceptors(new org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor())
.setAllowedOrigins("*");
.setAllowedOrigins(
"http://localhost:5173",
"https://wolf.robothanzo.dev",
"http://wolf.robothanzo.dev");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.mokulu.discord.oauth.DiscordOAuth;
import io.mokulu.discord.oauth.model.TokensResponse;
import io.mokulu.discord.oauth.model.User;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -41,13 +42,20 @@ public void login(@RequestParam(name = "guild_id", required = false) String guil
}

@GetMapping("/callback")
public void callback(@RequestParam String code, @RequestParam String state, HttpSession session,
HttpServletResponse response) throws IOException {
public void callback(@RequestParam String code, @RequestParam String state,
HttpServletRequest request, HttpServletResponse response) throws IOException {
try {
TokensResponse tokenResponse = discordOAuth.getTokens(code);
DiscordAPI discordAPI = new DiscordAPI(tokenResponse.getAccessToken());
User user = discordAPI.fetchUser();

// Prevent session fixation: invalidate old session and create new one
HttpSession oldSession = request.getSession(false);
if (oldSession != null) {
oldSession.invalidate();
}
HttpSession session = request.getSession(true);

// Store user in Session
AuthSession authSession = AuthSession.builder()
.userId(user.getId())
Expand All @@ -60,7 +68,7 @@ public void callback(@RequestParam String code, @RequestParam String state, Http

if (!"no_guild".equals(state)) {
try {
// Validate it's a number
// Validate it's a number - prevents open redirect attacks
long gid = Long.parseLong(state);
authSession.setGuildId(state); // Store as String

Expand All @@ -72,14 +80,22 @@ public void callback(@RequestParam String code, @RequestParam String state, Http
} else {
authSession.setRole(UserRole.SPECTATOR);
}

session.setAttribute("user", authSession);
response.sendRedirect(
System.getenv().getOrDefault("DASHBOARD_URL", "http://localhost:5173") + "/server/" + state);
} catch (NumberFormatException e) {
// Invalid guild ID format - redirect to server selection instead
log.warn("Invalid guild ID in OAuth state: {}", state);
authSession.setRole(UserRole.PENDING);
session.setAttribute("user", authSession);
response.sendRedirect(System.getenv().getOrDefault("DASHBOARD_URL", "http://localhost:5173") + "/");
} catch (Exception e) {
log.warn("Failed to set initial guild info: {}", state, e);
authSession.setRole(UserRole.PENDING);
session.setAttribute("user", authSession);
response.sendRedirect(System.getenv().getOrDefault("DASHBOARD_URL", "http://localhost:5173") + "/");
}

session.setAttribute("user", authSession);
response.sendRedirect(
System.getenv().getOrDefault("DASHBOARD_URL", "http://localhost:5173") + "/server/" + state);
} else {
authSession.setRole(UserRole.PENDING);
session.setAttribute("user", authSession);
Expand Down
38 changes: 23 additions & 15 deletions src/main/java/dev/robothanzo/werewolf/listeners/ButtonListener.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,18 @@ public void onButtonInteraction(@NotNull ButtonInteractionEvent event) {
event.getHook().editOriginal(":x: 你曾經參選過或正在參選,不得投票").queue();
return;
}
Candidate electedCandidate = candidates
.get(Integer.parseInt(customId.replaceAll("votePolice", "")));
if (electedCandidate != null) {
handleVote(event, candidates, electedCandidate);
// Broadcast update immediately
WerewolfApplication.gameSessionService.broadcastSessionUpdate(session);
} else {
event.getHook().editOriginal(":x: 找不到候選人").queue();
try {
int candidateId = Integer.parseInt(customId.replace("votePolice", ""));
Candidate electedCandidate = candidates.get(candidateId);
if (electedCandidate != null) {
handleVote(event, candidates, electedCandidate);
// Broadcast update immediately
WerewolfApplication.gameSessionService.broadcastSessionUpdate(session);
} else {
event.getHook().editOriginal(":x: 找不到候選人").queue();
}
} catch (NumberFormatException e) {
event.getHook().editOriginal(":x: 無效的投票選項").queue();
}
} else {
event.getHook().editOriginal(":x: 投票已過期").queue();
Expand All @@ -106,13 +110,17 @@ public void onButtonInteraction(@NotNull ButtonInteractionEvent event) {
event.getHook().editOriginal(":x: 你正在和別人進行放逐辯論,不得投票").queue();
return;
}
Map<Integer, Candidate> candidates = Poll.expelCandidates
.get(Objects.requireNonNull(event.getGuild()).getIdLong());
Candidate electedCandidate = candidates
.get(Integer.parseInt(customId.replaceAll("voteExpel", "")));
handleVote(event, candidates, electedCandidate);
// Broadcast update immediately for expel (user requested realtime voting)
WerewolfApplication.gameSessionService.broadcastSessionUpdate(session);
try {
Map<Integer, Candidate> candidates = Poll.expelCandidates
.get(Objects.requireNonNull(event.getGuild()).getIdLong());
int candidateId = Integer.parseInt(customId.replace("voteExpel", ""));
Candidate electedCandidate = candidates.get(candidateId);
handleVote(event, candidates, electedCandidate);
// Broadcast update immediately for expel (user requested realtime voting)
WerewolfApplication.gameSessionService.broadcastSessionUpdate(session);
} catch (NumberFormatException e) {
event.getHook().editOriginal(":x: 無效的投票選項").queue();
}
} else {
event.getHook().editOriginal(":x: 投票已過期").queue();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,14 @@ public List<Map<String, Object>> playersToJSON(Session session) {
}

players.sort((a, b) -> {
int idA = Integer.parseInt((String) a.get("id"));
int idB = Integer.parseInt((String) b.get("id"));
return Integer.compare(idA, idB);
try {
int idA = Integer.parseInt((String) a.get("id"));
int idB = Integer.parseInt((String) b.get("id"));
return Integer.compare(idA, idB);
} catch (NumberFormatException e) {
// If parsing fails, treat as equal or use string comparison as fallback
return 0;
}
});

return players;
Expand Down
50 changes: 49 additions & 1 deletion src/main/java/dev/robothanzo/werewolf/utils/IdentityUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,40 @@

import dev.robothanzo.werewolf.database.documents.AuthSession;
import dev.robothanzo.werewolf.database.documents.UserRole;
import dev.robothanzo.werewolf.service.DiscordService;
import lombok.RequiredArgsConstructor;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

@Component
@RequiredArgsConstructor
public class IdentityUtils {

private final DiscordService discordService;

// Simple cache to reduce Discord API calls - entries expire after they're created
private static class MembershipCacheEntry {
final boolean isMember;
final long timestamp;

MembershipCacheEntry(boolean isMember) {
this.isMember = isMember;
this.timestamp = System.currentTimeMillis();
}

boolean isExpired() {
// Cache for 5 minutes
return System.currentTimeMillis() - timestamp > 300_000;
}
}

private final Map<String, MembershipCacheEntry> membershipCache = new ConcurrentHashMap<>();

public Optional<AuthSession> getCurrentUser() {
Authentication auth = SecurityContextHolder.getContext().getAuthentication();
if (auth != null && auth.getPrincipal() instanceof AuthSession) {
Expand All @@ -21,7 +46,30 @@ public Optional<AuthSession> getCurrentUser() {

public boolean hasAccess(long guildId) {
return getCurrentUser()
.map(user -> String.valueOf(guildId).equals(user.getGuildId()))
.map(user -> {
// First check if the guild ID matches what's stored in the session
if (!String.valueOf(guildId).equals(user.getGuildId())) {
return false;
}

// Check cache first
String cacheKey = guildId + ":" + user.getUserId();
MembershipCacheEntry cached = membershipCache.get(cacheKey);
if (cached != null && !cached.isExpired()) {
return cached.isMember;
}

// Cache miss or expired - verify the user is still actually a member of the guild
boolean isMember = discordService.getMember(guildId, user.getUserId()) != null;
membershipCache.put(cacheKey, new MembershipCacheEntry(isMember));

// Clean up expired entries periodically (simple approach)
if (membershipCache.size() > 1000) {
membershipCache.entrySet().removeIf(e -> e.getValue().isExpired());
}

return isMember;
})
.orElse(false);
}

Expand Down