|
| 1 | +import json |
1 | 2 | from unittest import TestCase |
2 | 3 |
|
3 | 4 | from django.db import models |
|
6 | 7 | from mock import patch, Mock, PropertyMock |
7 | 8 |
|
8 | 9 | from django_elasticsearch_dsl import fields |
9 | | -from django_elasticsearch_dsl.documents import DocType |
| 10 | +from django_elasticsearch_dsl.documents import DocType, Document |
10 | 11 | from django_elasticsearch_dsl.exceptions import (ModelFieldNotMappedError, |
11 | 12 | RedeclaredFieldError) |
12 | 13 | from django_elasticsearch_dsl.registries import registry |
13 | 14 | from tests import ES_MAJOR_VERSION |
14 | 15 |
|
| 16 | +from .models import Article |
| 17 | +from .documents import ArticleDocument, ArticleWithSlugAsIdDocument |
| 18 | + |
15 | 19 |
|
16 | 20 | class Car(models.Model): |
17 | 21 | name = models.CharField(max_length=255) |
@@ -346,3 +350,64 @@ def test_init_prepare_results(self): |
346 | 350 | self.assertEqual(sorted([tuple(x) for x in m.method_calls], key=lambda _: _[0]), |
347 | 351 | [('name', (), {}), ('price', (), {}), ('type', (), {})] |
348 | 352 | ) |
| 353 | + |
| 354 | + # Mock the elasticsearch connection because we need to execute the bulk so that the generator |
| 355 | + # got iterated and generate_id called. |
| 356 | + # If we mock the bulk in django_elasticsearch_dsl.document |
| 357 | + # the actual bulk will be never called and the test will fail |
| 358 | + @patch('elasticsearch_dsl.connections.Elasticsearch.bulk') |
| 359 | + def test_default_generate_id_is_called(self, _): |
| 360 | + article = Article( |
| 361 | + id=124594, |
| 362 | + slug='some-article', |
| 363 | + ) |
| 364 | + @registry.register_document |
| 365 | + class ArticleDocument(DocType): |
| 366 | + class Django: |
| 367 | + model = Article |
| 368 | + fields = [ |
| 369 | + 'slug', |
| 370 | + ] |
| 371 | + |
| 372 | + class Index: |
| 373 | + name = 'test_articles' |
| 374 | + settings = { |
| 375 | + 'number_of_shards': 1, |
| 376 | + 'number_of_replicas': 0, |
| 377 | + } |
| 378 | + |
| 379 | + with patch.object(ArticleDocument, 'generate_id', |
| 380 | + return_value=article.id) as patched_method: |
| 381 | + d = ArticleDocument() |
| 382 | + d.update(article) |
| 383 | + patched_method.assert_called() |
| 384 | + |
| 385 | + @patch('elasticsearch_dsl.connections.Elasticsearch.bulk') |
| 386 | + def test_custom_generate_id_is_called(self, mock_bulk): |
| 387 | + article = Article( |
| 388 | + id=54218, |
| 389 | + slug='some-article-2', |
| 390 | + ) |
| 391 | + |
| 392 | + @registry.register_document |
| 393 | + class ArticleDocument(DocType): |
| 394 | + class Django: |
| 395 | + model = Article |
| 396 | + fields = [ |
| 397 | + 'slug', |
| 398 | + ] |
| 399 | + |
| 400 | + class Index: |
| 401 | + name = 'test_articles' |
| 402 | + |
| 403 | + @classmethod |
| 404 | + def generate_id(cls, article): |
| 405 | + return article.slug |
| 406 | + |
| 407 | + d = ArticleDocument() |
| 408 | + d.update(article) |
| 409 | + |
| 410 | + # Get the data from the elasticsearch low level API because |
| 411 | + # The generator get executed there. |
| 412 | + data = json.loads(mock_bulk.call_args[0][0].split("\n")[0]) |
| 413 | + assert data["index"]["_id"] == article.slug |
0 commit comments