diff --git a/CHANGELOG.md b/CHANGELOG.md index b98d853..7cf4be3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.0.2] 2026-02-22 - Improve structured outputs functionality + tutorial -[Unreleased]: https://github.com/cleanlab/tlm/compare/v0.0.2...HEAD +## [0.0.3] 2026-02-22 +- Fix `get_untrustworthy_fields()` for score outputs + +[Unreleased]: https://github.com/cleanlab/tlm/compare/v0.0.3...HEAD +[0.0.3]: https://github.com/cleanlab/tlm/commits/v0.0.3 [0.0.2]: https://github.com/cleanlab/tlm/commits/v0.0.2 [0.0.1]: https://github.com/cleanlab/tlm/commits/v0.0.1 [0.0.0]: https://github.com/cleanlab/tlm/commits/v0.0.0 diff --git a/tlm/__about__.py b/tlm/__about__.py index 3b93d0b..27fdca4 100644 --- a/tlm/__about__.py +++ b/tlm/__about__.py @@ -1 +1 @@ -__version__ = "0.0.2" +__version__ = "0.0.3" diff --git a/tlm/utils/structured_output_utils.py b/tlm/utils/structured_output_utils.py index 6fd150b..de455f7 100644 --- a/tlm/utils/structured_output_utils.py +++ b/tlm/utils/structured_output_utils.py @@ -10,7 +10,17 @@ def _get_untrustworthy_fields( display_details: bool = True, ) -> list[str]: tlm_metadata = tlm_result["metadata"] - response_text = tlm_result["response"].choices[0].message.content # type: ignore + response = tlm_result["response"] + + # for score completions + if isinstance(response, dict) and "chat_completion" in response: + response = response["chat_completion"] + + try: + response_text = response.choices[0].message.content # type: ignore + except Exception: + # sometimes tlm_result["response"] is a dictionary + response_text = response["choices"][0]["message"]["content"] # type: ignore if tlm_metadata is None or "per_field_score" not in tlm_metadata: raise ValueError( @@ -18,16 +28,18 @@ def _get_untrustworthy_fields( "`get_untrustworthy_fields()` can only be called scoring structured outputs responses." ) - try: - so_response = json.loads(response_text) - except Exception: - pass - try: - so_response = ast.literal_eval(response_text) - except Exception: - raise ValueError( - "The LLM response must be a valid JSON output (use `response_format` to specify the output format)" - ) + if isinstance(response_text, dict): + so_response = response_text + else: + try: + so_response = json.loads(response_text) + except Exception: + try: + so_response = ast.literal_eval(response_text) + except Exception: + raise ValueError( + "The LLM response must be a valid JSON output (use `response_format` to specify the output format)" + ) per_field_score = tlm_metadata["per_field_score"] per_score_details = []