Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,11 @@ func standardMoves(pos *Position, first bool, unsafeOnly bool) []Move {
// Reuse move struct by setting fields directly
m.s1 = Square(s1)
m.s2 = Square(s2)
m.tags = 0 // Reset tags

if (p == WhitePawn && Square(s2).Rank() == Rank8) || (p == BlackPawn && Square(s2).Rank() == Rank1) {
for _, pt := range promoPieceTypes {
m.promo = pt
addTags(&m, pos)
m.tags = moveTags(m, pos)
if m.HasTag(inCheck) == unsafeOnly {
// Copy the valid move to the array
moves[count] = m
Expand All @@ -152,8 +151,8 @@ func standardMoves(pos *Position, first bool, unsafeOnly bool) []Move {
}
} else {
m.promo = 0
addTags(&m, pos)
if m.HasTag(inCheck) == unsafeOnly {
m.tags = moveTags(m, pos)
if m.HasTag(inCheck) == unsafeOnly {
moves[count] = m
Comment on lines 152 to 156
count++
if first {
Expand All @@ -173,41 +172,46 @@ func standardMoves(pos *Position, first bool, unsafeOnly bool) []Move {
return result
}

// addTags updates a move's tags based on the resulting position.
// moveTags computes all tags for a move from scratch based on the resulting position.
// Tags include:
// - Capture: The move captures an opponent's piece
// - EnPassant: The move is an en passant capture
// - Check: The move puts the opponent in check
// - inCheck: The move leaves the moving side's king in check (illegal)
// - KingSideCastle: The move is a king-side castle
// - QueenSideCastle: The move is a queen-side castle
func addTags(m *Move, pos *Position) {
func moveTags(m Move, pos *Position) MoveTag {
var tags MoveTag
p := pos.board.Piece(m.s1)
if pos.board.isOccupied(m.s2) {
m.AddTag(Capture)
tags |= Capture
} else if m.s2 == pos.enPassantSquare && p.Type() == Pawn {
m.AddTag(EnPassant)
tags |= EnPassant
}
// determine if move is castle
if (p == WhiteKing && m.s1 == E1) || (p == BlackKing && m.s1 == E8) {
switch m.s2 {
case C1, C8:
m.AddTag(QueenSideCastle)
tags |= QueenSideCastle
case G1, G8:
m.AddTag(KingSideCastle)
tags |= KingSideCastle
}
}
// apply preliminary tags to a local copy so board.update reads them correctly
local := m
local.tags = tags
// determine if in check after move (makes move invalid)
cp := pos.copy()
cp.board.update(m)
cp.board.update(&local)
if isInCheck(cp) {
m.AddTag(inCheck)
tags |= inCheck
}
// determine if opponent in check after move
cp.turn = cp.turn.Other()
if isInCheck(cp) {
m.AddTag(Check)
tags |= Check
}
return tags
}

// isInCheck returns true if the side to move is in check in the given position.
Expand Down Expand Up @@ -344,8 +348,7 @@ func castleMoves(pos *Position) []Move {
!squaresAreAttacked(pos, F1, G1) &&
!pos.inCheck {
m := Move{s1: E1, s2: G1}
m.AddTag(KingSideCastle)
addTags(&m, pos)
m.tags = moveTags(m, pos)
moves[count] = m
count++
}
Expand All @@ -356,8 +359,7 @@ func castleMoves(pos *Position) []Move {
!squaresAreAttacked(pos, C1, D1) &&
!pos.inCheck {
m := Move{s1: E1, s2: C1}
m.AddTag(QueenSideCastle)
addTags(&m, pos)
m.tags = moveTags(m, pos)
moves[count] = m
count++
}
Expand All @@ -368,8 +370,7 @@ func castleMoves(pos *Position) []Move {
!squaresAreAttacked(pos, F8, G8) &&
!pos.inCheck {
m := Move{s1: E8, s2: G8}
m.AddTag(KingSideCastle)
addTags(&m, pos)
m.tags = moveTags(m, pos)
moves[count] = m
count++
}
Expand All @@ -380,8 +381,7 @@ func castleMoves(pos *Position) []Move {
!squaresAreAttacked(pos, C8, D8) &&
!pos.inCheck {
m := Move{s1: E8, s2: C8}
m.AddTag(QueenSideCastle)
addTags(&m, pos)
m.tags = moveTags(m, pos)
moves[count] = m
count++
}
Expand Down
66 changes: 64 additions & 2 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func BenchmarkStandardMoves_BoardDensity(b *testing.B) {
}
}

func TestAddTags(t *testing.T) {
func TestMoveTags(t *testing.T) {
tests := []struct {
name string
move Move
Expand Down Expand Up @@ -165,7 +165,7 @@ func TestAddTags(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
addTags(&test.move, mustPosition(test.fen))
test.move.tags = moveTags(test.move, mustPosition(test.fen))

if test.move.tags != test.want {
t.Errorf("fen: %s | move: %s\ntags(%d) == expected_tags(%d)", test.fen, test.move.String(), test.move.tags, test.want)
Expand All @@ -182,6 +182,68 @@ func TestUnsafeMoves_StartingPosition(t *testing.T) {
}
}

func TestPromotionCheckTagIsolation(t *testing.T) {
// FEN from issue #112: k7/4P3/8/8/8/8/8/K7 w - - 0 1
// White pawn on E7 can promote to E8. Only Queen and Rook give check
// to the black king on A8. Bishop and Knight should NOT have Check.
pos := mustPosition("k7/4P3/8/8/8/8/8/K7 w - - 0 1")
moves := pos.ValidMoves()

// Verify total move count (3 king moves + 4 promotions = 7)
if len(moves) != 7 {
t.Fatalf("expected 7 moves, got %d", len(moves))
}

// Helper: find move by s1, s2, promo fields (order-independent)
findMove := func(s1, s2 Square, promo PieceType) (Move, bool) {
for _, m := range moves {
if m.s1 == s1 && m.s2 == s2 && m.promo == promo {
return m, true
}
}
return Move{}, false
}

tests := []struct {
name string
promo PieceType
wantCheck bool
}{
{"queen promotion with check", Queen, true},
{"rook promotion with check", Rook, true},
{"bishop promotion without check", Bishop, false},
{"knight promotion without check", Knight, false},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
m, ok := findMove(E7, E8, test.promo)
if !ok {
t.Fatalf("expected to find E7-E8=%s promotion move", test.promo.String())
}
gotCheck := m.HasTag(Check)
if gotCheck != test.wantCheck {
t.Errorf("E7-E8=%s: Check=%v, want %v", test.promo.String(), gotCheck, test.wantCheck)
}
})
}
}

func TestPromotionNoCheck(t *testing.T) {
// Position where no promotion gives check against an existing black king.
// Black king on A6 is not attacked by any promoted piece on E8.
pos := mustPosition("8/4P3/k7/8/8/8/8/7K w - - 0 1")
moves := pos.ValidMoves()

for _, m := range moves {
if m.s1 == E7 && m.s2 == E8 {
if m.HasTag(Check) {
t.Errorf("E7-E8=%s should NOT have Check tag", m.promo.String())
}
}
}
}

func TestUnsafeMoves_PinnedKnight(t *testing.T) {
pos := mustPosition("4k3/8/8/8/1b6/2N5/8/4K3 w - - 0 1")
moves := engine{}.UnsafeMoves(pos)
Expand Down
2 changes: 1 addition & 1 deletion notation.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (UCINotation) Decode(pos *Position, s string) (*Move, error) {
return &m, nil
}

addTags(&m, pos)
m.tags = moveTags(m, pos)

m.position = pos.Update(&m)

Expand Down
Loading