From c78b111f1b0f38ef021352e5eeaba02667458b4c Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Tue, 2 Dec 2025 16:01:15 -0800 Subject: [PATCH] ANE llama runner Summary: This fixes issues with the ANE-friendly llama on iOS26. See updated readme.md for more information. A key change is decomposing SDPA into matmuls and softmax because iOS26 has a bug in its implementation of SDPA on the ANE. Differential Revision: D88083155 --- examples/apple/coreml/llama/export.py | 244 ++++--- .../apple/coreml/llama/llama_transformer.py | 388 ++++++++++- examples/apple/coreml/llama/main.cpp | 294 ++++++++ examples/apple/coreml/llama/model_manager.cpp | 644 ++++++++++++++++++ examples/apple/coreml/llama/model_manager.hpp | 218 ++++++ examples/apple/coreml/llama/readme.md | 55 +- 6 files changed, 1715 insertions(+), 128 deletions(-) create mode 100644 examples/apple/coreml/llama/main.cpp create mode 100644 examples/apple/coreml/llama/model_manager.cpp create mode 100644 examples/apple/coreml/llama/model_manager.hpp diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index af2fa3c74ee..1e518b47f68 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import os import coremltools as ct import torch @@ -14,6 +15,7 @@ from executorch.examples.apple.coreml.llama.llama_transformer import ( InputManager, load_model, + load_model_in_pieces_ITO, ) from executorch.examples.apple.coreml.llama.utils import ( replace_linear_with_split_linear, @@ -28,10 +30,9 @@ from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ -from torchao.utils import unwrap_tensor_subclass -def main() -> None: +def main() -> None: # noqa: C901 parser = argparse.ArgumentParser() parser.add_argument( "-n", @@ -77,7 +78,10 @@ def main() -> None: parser.add_argument( "--coreml-quantize", default=None, - choices=["b4w", "c4w"], + choices=[ + "b4w", + "c4w", + ], help="This option is only for coreml: Use coreml quantization, e.g. 
b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)", ) parser.add_argument( @@ -102,67 +106,69 @@ def main() -> None: type=str, default="fp16", ) + parser.add_argument( + "--export_in_parts", + action="store_true", + help="Export model in 3 parts: input_block.pte, transformer_block.pte (all layers combined), output_block.pte", + ) export_args = parser.parse_args() - model = load_model( - export_args.checkpoint, - export_args.params, - max_seq_length=export_args.max_seq_length, - use_cache_list=export_args.use_cache_list, - ) float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[ export_args.dtype ] # dtype for model/inputs - model.eval() - model.to(float_dtype) - - if export_args.target_split_size is not None: - replace_linear_with_split_linear( - model, - out_target_split_size=export_args.target_split_size, - out_max_splits=export_args.max_splits, - # I have not found splitting on in_features to be beneficial, - # and it often leads to OOM so I set in_max_splits to 1 - in_target_split_size=1, - in_max_splits=1, - ) + def maybe_split_model(model): + if export_args.target_split_size is not None: + replace_linear_with_split_linear( + model, + out_target_split_size=export_args.target_split_size, + out_max_splits=export_args.max_splits, + # I have not found splitting on in_features to be beneficial, + # and it often leads to OOM so I set in_max_splits to 1 + in_target_split_size=1, + in_max_splits=1, + ) - # Quantization - if export_args.embedding_quantize: - bitwidth, group_size = export_args.embedding_quantize.split(",") - bitwidth = int(bitwidth) - assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" - group_size = int(group_size) - if group_size == 0: - granularity = PerAxis(0) - else: - granularity = PerGroup(group_size) - weight_dtype = getattr(torch, f"int{bitwidth}") + def maybe_quantize_model(model): + if export_args.embedding_quantize: + bitwidth, group_size = export_args.embedding_quantize.split(",") + bitwidth = int(bitwidth) + assert bitwidth in [ + 4, + 8, + ], "CoreML only supports 4-bit and 8-bit quantization" + group_size = int(group_size) + if group_size == 0: + granularity = PerAxis(0) + else: + granularity = PerGroup(group_size) + weight_dtype = getattr(torch, f"int{bitwidth}") - quantize_( - model, - IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity), - lambda m, fqn: isinstance(m, torch.nn.Embedding), - ) + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=weight_dtype, granularity=granularity + ), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) - if export_args.coreml_quantize == "b4w": - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=torch.int4, - granularity=PerGroup(32), - ), - ) - elif export_args.coreml_quantize == "c4w": - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=torch.int4, - granularity=PerAxis(0), - ), - ) + if export_args.coreml_quantize == "b4w": + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerGroup(32), + ), + ) + elif export_args.coreml_quantize == "c4w": + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerAxis(0), + ), + ) compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16] minimum_deployment_target=ct.target.iOS18, @@ -179,45 +185,113 @@ def main() -> None: skip_ops_for_coreml_delegation=[], ) - input_manager = InputManager( - n_layers=model.params.n_layers, - max_batch_size=model.params.max_batch_size, - n_kv_heads=model.params.n_kv_heads, 
- max_seq_length=model.params.max_seq_len, - head_dim=model.params.head_dim, - use_cache_list=export_args.use_cache_list, - seq_length=export_args.seq_length, - dtype=float_dtype, - minus_infinity=-30000, - cache_size=export_args.cache_size, + executorch_config = ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), ) - example_inputs = input_manager.get_inputs(tokens=[0]) - model = unwrap_tensor_subclass(model) + def strip_pte(name): + if name.endswith(".pte"): + return name[:-4] + else: + return name - ep = torch.export.export(model, example_inputs, strict=True) - print("Exported program") - print(ep) + if not export_args.export_in_parts: + # Mode 0: Single monolithic model + model = load_model( + export_args.checkpoint, + export_args.params, + max_seq_length=export_args.max_seq_length, + use_cache_list=export_args.use_cache_list, + ) + input_manager = InputManager( + n_layers=model.params.n_layers, + max_batch_size=model.params.max_batch_size, + n_kv_heads=model.params.n_kv_heads, + max_seq_length=model.params.max_seq_len, + head_dim=model.params.head_dim, + use_cache_list=export_args.use_cache_list, + seq_length=export_args.seq_length, + dtype=float_dtype, + minus_infinity=-30000, + cache_size=export_args.cache_size, + ) + example_inputs = input_manager.get_inputs(tokens=[0]) + model.eval() + model = model.to(float_dtype) + print("Model", model) + maybe_split_model(model) + print("Model after split", model) + maybe_quantize_model(model) + print("Model after quantize", model) - edge_manager = to_edge_transform_and_lower( - ep, - partitioner=[partitioner], - ) + ep = torch.export.export(model, example_inputs, strict=True) + ep = ep.run_decompositions({}) + print("Exported program") + print(ep) + + edge_manager = to_edge_transform_and_lower( + ep, + partitioner=[partitioner], + ) + + print("Delegated program") + print(format_delegated_graph(edge_manager.exported_program().graph_module)) - print("Delegated program") - print(format_delegated_graph(edge_manager.exported_program().graph_module)) + executorch_program = edge_manager.to_executorch(executorch_config) + filename = save_pte_program(executorch_program, export_args.output_name) + print(f"Saved Executorch program to local {filename}") - executorch_program = edge_manager.to_executorch( - ExecutorchBackendConfig( - extract_delegate_segments=True, - do_quant_fusion_and_const_prop=True, - memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), - sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), + else: + # Mode 1: Export in 3 parts with single transformer block + models, example_inputs = load_model_in_pieces_ITO( + export_args.checkpoint, + export_args.params, + max_seq_length=export_args.max_seq_length, + seq_length=export_args.seq_length, + float_dtype=float_dtype, ) - ) - filename = save_pte_program(executorch_program, export_args.output_name) - print(f"Saved Executorch program to local {filename}") + for i, model in enumerate(models): + if i == 0: + ex_inputs = example_inputs[i] + suffix = "input_block" + elif i == len(models) - 1: + ex_inputs = example_inputs[-1] + suffix = "output_block" + else: + ex_inputs = example_inputs[1] + suffix = "transformer_block" + + model.eval() + model = model.to(float_dtype) + print(f"Model {i}", model) + if i == len(models) - 1: + maybe_split_model(model) + print(f"Model {i} after split", model) + 
maybe_quantize_model(model) + print(f"Model {i} after quantize", model) + ep = torch.export.export(model, ex_inputs, strict=True) + ep = ep.run_decompositions({}) + print(f"Exported program for model {i}", ep) + + edge_manager = to_edge_transform_and_lower( + ep, + partitioner=[partitioner], + ) + + print(f"Delegated program for model {i}") + print(format_delegated_graph(edge_manager.exported_program().graph_module)) + + executorch_program = edge_manager.to_executorch(executorch_config) + os.makedirs(f"{strip_pte(export_args.output_name)}", exist_ok=True) + filename = save_pte_program( + executorch_program, + f"{strip_pte(export_args.output_name)}/{suffix}.pte", + ) + print(f"Saved Executorch program to local {filename}") if __name__ == "__main__": diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index ae98c327b45..0364ff99401 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -14,7 +14,7 @@ import torch import torch.nn.functional as F -from executorch.examples.models.llama.norm import RMSNorm +from executorch.examples.models.llama.norm import RMSNorm # noqa F401 from executorch.examples.models.llama.rope import ( hf_apply_rotary_emb, @@ -34,6 +34,15 @@ def find_multiple(n: int, k: int) -> int: return n + k - (n % k) +def silu_approx(x): + x = x.clamp(-3, 3) + x2 = x * x + x4 = x2 * x2 + x6 = x4 * x2 + res = 0.0017 + 0.5 * x + 0.2423 * x2 - 0.0153 * x4 + 0.00057 * x6 + return res + + @dataclass class ModelArgs: dim: int = 2048 @@ -109,6 +118,17 @@ def __post_init__(self): self.head_dim = self.dim // self.n_heads +def rms_norm_fp16_stable(x, eps=1e-5, min_scale=1e-3): + amax = x.abs().amax(dim=-1, keepdim=True) + scale = amax.clamp(min=min_scale) + x_scaled = x / scale + + var = torch.square(x_scaled).mean(dim=-1, keepdim=True) + rms = torch.sqrt(var + eps) + y = x_scaled / rms + return y + + class CoreMLRMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): """ @@ -146,10 +166,11 @@ def _norm(self, x): # In future, we want to add CoreML support for the functional RMSNorm op # We have yet to do large scale evaluations on the numeric stability of this solution, but note that # it appears better than what exists currently (removing FP32 casts and using FP16) + norm = torch.linalg.vector_norm(x, dim=-1, keepdim=True) rms_norm_eps0 = ( x - * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype)) - * torch.reciprocal(torch.linalg.vector_norm(x, dim=-1, keepdim=True)) + * (torch.sqrt(torch.tensor(self.dim, dtype=x.dtype)) / norm) + # * torch.reciprocal(torch.linalg.vector_norm(x, dim=-1, keepdim=True)) ) return rms_norm_eps0 @@ -168,6 +189,12 @@ def forward(self, x): return output * self.weight +_RMS_NORM = CoreMLRMSNorm +_DECOMPOSE_SDPA = True +_USE_SOFTMAX = True +_USE_SILU_APPROX = False + + class Rope(torch.nn.Module): def __init__(self, params: ModelArgs): super().__init__() @@ -249,7 +276,15 @@ def __init__(self, args: ModelArgs): self.w3 = nn.Linear(args.dim, hidden_dim, bias=False) def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + t1 = self.w1(x) + if _USE_SILU_APPROX: + t1 = silu_approx(t1) + else: + t1 = F.silu(t1) + t2 = self.w3(x) + out = t1 * t2 + out = self.w2(out) + return out class ConditionalFeedForward(nn.Module): @@ -327,8 +362,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): if self.use_qk_norm: q_norm_dim = self.head_dim k_norm_dim = self.head_dim - self.q_norm_fn = RMSNorm(q_norm_dim, 
eps=args.norm_eps) - self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps) + self.q_norm_fn = _RMS_NORM(q_norm_dim, eps=args.norm_eps) + self.k_norm_fn = _RMS_NORM(k_norm_dim, eps=args.norm_eps) def forward( self, @@ -364,14 +399,45 @@ def forward( k = torch.concat([k_cache, k], dim=2) v = torch.concat([v_cache, v], dim=2) + # TODO: I'm pretty sure the MB version of SDPA does not require this repeat_interleave, # grouped multiquery attention: expand out keys and values if self.n_rep > 1: k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) - output = torch.ops.aten.scaled_dot_product_attention.default( - q, k, v, attn_mask=attn_mask - ) + if not _DECOMPOSE_SDPA: + output = torch.ops.aten.scaled_dot_product_attention.default( + q, k, v, attn_mask=attn_mask + ) + else: + # ------------------------------ + # Manual SDPA: matmuls + softmax + # q: (B, H, T_q, D) + # k: (B, H, T_k, D) + # v: (B, H, T_k, D) + # attn_mask: broadcastable to (B, H, T_q, T_k) + # ------------------------------ + d = q.size(-1) + # (B, H, T_q, T_k) + scores = torch.matmul(q, k.transpose(-2, -1)) / (d**0.5) + + if attn_mask is not None: + # attn_mask is already used this way with SDPA, keep same semantics: + # 0.0 for allowed, -inf for disallowed, added to scores. + scores = scores + attn_mask + + if _USE_SOFTMAX: + # (B, H, T_q, T_k) + attn_weights = torch.softmax(scores, dim=-1) + else: + scores = scores.clamp(min=-60.0, max=60.0) + scores_max, _ = scores.max(dim=-1, keepdim=True) # (B, H, T_q, 1) + scores_exp = torch.exp(scores - scores_max) + attn_weights = scores_exp / scores_exp.sum(dim=-1, keepdim=True) + + # (B, H, T_q, D) + output = torch.matmul(attn_weights, v) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) output = self.wo(output) return output, new_k, new_v @@ -388,8 +454,8 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): self.block_sparse_moe = MOEFeedForward(args) else: self.feed_forward = FeedForward(args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = _RMS_NORM(args.dim, eps=args.norm_eps) + self.ffn_norm = _RMS_NORM(args.dim, eps=args.norm_eps) def forward( self, @@ -404,12 +470,123 @@ def forward( h, new_k, new_v = self.attention.forward( norm_emb, freqs_cos, freqs_sin, k_cache, v_cache, attn_mask ) - h = x + h - out = h + self.feed_forward(self.ffn_norm(h)) + tmp = self.feed_forward(self.ffn_norm(h)) + out = h + tmp return out, new_k, new_v +class TransformerBlockSequence(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.layers = torch.nn.ModuleList() + self.rope = Rope(args) + for layer_id in range(args.n_layers): + self.layers.append(TransformerBlock(layer_id, args, self.rope)) + + def forward( + self, + x, + freqs_cos, + freqs_sin, + *args, # k_caches (n_layers tensors) + v_caches (n_layers tensors) + attn_mask + ): + # After torch.export, list arguments get flattened into individual args + # Reconstruct them: args = [k_cache_0, ..., k_cache_n, v_cache_0, ..., v_cache_n, attn_mask] + n_layers = len(self.layers) + k_caches = list(args[:n_layers]) + v_caches = list(args[n_layers : 2 * n_layers]) + attn_mask = args[2 * n_layers] + + new_k_caches = [] + new_v_caches = [] + for i, layer in enumerate(self.layers): + x, new_k, new_v = layer( + x, freqs_cos, freqs_sin, k_caches[i], v_caches[i], attn_mask + ) + new_k_caches.append(new_k) + new_v_caches.append(new_v) + + return x, new_k_caches, 
new_v_caches + + +class AttentionBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.head_dim + self.attention = Attention(args, layer_id, rope) + self.attention_norm = _RMS_NORM(args.dim, eps=args.norm_eps) + + def forward( + self, + x, + freqs_cos, + freqs_sin, + k_cache, + v_cache, + attn_mask, + ): # x: 1xN + norm_emb = self.attention_norm(x) + h, new_k, new_v = self.attention.forward( + norm_emb, freqs_cos, freqs_sin, k_cache, v_cache, attn_mask + ) + h = x + h + return h, new_k, new_v + + +class FeedForwardBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.head_dim + if args.moe: + self.block_sparse_moe = MOEFeedForward(args) + else: + self.feed_forward = FeedForward(args) + self.ffn_norm = _RMS_NORM(args.dim, eps=args.norm_eps) + + def forward( + self, + h, + ): # x: 1xN + tmp = self.feed_forward(self.ffn_norm(h)) + out = h + tmp + return out + + +class InputBlock(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.rope = Rope(params) + self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) + + def forward(self, tokens: torch.LongTensor, input_pos: torch.LongTensor): + h = self.tok_embeddings(tokens) + seqlen = h.shape[1] + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen) + return h, freqs_cos, freqs_sin + + +class OutputBlock(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.generate_full_logits = params.generate_full_logits + self.norm = _RMS_NORM(params.dim, eps=params.norm_eps) + self.output = nn.Linear(params.dim, params.vocab_size, bias=False) + + def forward(self, h, input_length: torch.LongTensor): + if not self.generate_full_logits: + # Only the last logit is used for the new generated token + h = h[:, input_length - 1, :].squeeze(1) + h = self.norm(h) + logits = self.output(h) + return logits + + class Transformer(nn.Module): def __init__(self, params: ModelArgs): super().__init__() @@ -422,7 +599,7 @@ def __init__(self, params: ModelArgs): self.layers = torch.nn.ModuleList() for layer_id in range(params.n_layers): self.layers.append(TransformerBlock(layer_id, params, self.rope)) - self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.norm = _RMS_NORM(params.dim, eps=params.norm_eps) self.output = nn.Linear(params.dim, params.vocab_size, bias=False) self.generate_full_logits = params.generate_full_logits self.max_seq_len = params.max_seq_len @@ -510,6 +687,189 @@ def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list): return model +def load_model_in_pieces_ITO( + checkpoint_path, + params_path, + max_seq_length, + seq_length, + float_dtype: torch.dtype = torch.float16, +): + """ + Loads model in 3 pieces: input, single transformer block (all layers), output + """ + import json + + with open(params_path, "r") as f: + params = json.loads(f.read()) + + args = ModelArgs( + max_seq_len=max_seq_length, + generate_full_logits=False, + **params, + ) + + with torch.device("meta"): + in_block = InputBlock(args) + out_block = OutputBlock(args) + transformer_blocks = TransformerBlockSequence(args) + + checkpoint = torch.load( + checkpoint_path, map_location="cpu", mmap=True, weights_only=True + ) + if "model" in checkpoint: + checkpoint = checkpoint["model"] + + for model in [in_block, out_block]: + missing, unexpected = 
model.load_state_dict( + checkpoint, + strict=False, + assign=True, + ) + assert len(missing) == 0 + + missing, unexpected = transformer_blocks.load_state_dict( + checkpoint, + strict=False, + assign=True, + ) + assert len(missing) == 0 + + cache_shape = (1, args.n_kv_heads, max_seq_length - seq_length, args.head_dim) + attn_mask_shape = (seq_length, max_seq_length) + freqs_shape = (seq_length, args.head_dim // 2) + h_shape = (1, seq_length, args.dim) + + # Create example inputs for TransformerBlockSequence + # The forward method uses *args to receive k_caches and v_caches, so we need to unpack them + transformer_example_inputs = ( + torch.zeros(h_shape, dtype=float_dtype), # h + torch.zeros(freqs_shape, dtype=float_dtype), # freqs_cos + torch.zeros(freqs_shape, dtype=float_dtype), # freqs_sin + ) + # Add unpacked k_caches + for _ in range(args.n_layers): + transformer_example_inputs += (torch.zeros(cache_shape, dtype=float_dtype),) + # Add unpacked v_caches + for _ in range(args.n_layers): + transformer_example_inputs += (torch.zeros(cache_shape, dtype=float_dtype),) + # Add attn_mask + transformer_example_inputs += (torch.zeros(attn_mask_shape, dtype=float_dtype),) + + example_inputs = [ + ( + torch.zeros(1, seq_length, dtype=torch.int64), + torch.tensor([0], dtype=torch.long), + ), # InputBlock + transformer_example_inputs, # TransformerBlockSequence + ( # OutputBlock + torch.zeros(h_shape, dtype=float_dtype), # h + torch.tensor([0], dtype=torch.long), # input_length + ), + ] + models = [in_block, transformer_blocks, out_block] + + for i in range(len(models)): + models[i] = models[i].to(float_dtype) + + return models, example_inputs + + +def load_model_in_pieces_IAFO( + checkpoint_path, + params_path, + max_seq_length, + seq_length, + float_dtype: torch.dtype = torch.float16, +): + """ + Loads model in pieces: input, [repeating attention, feedforward], output + """ + import json + + with open(params_path, "r") as f: + params = json.loads(f.read()) + + args = ModelArgs( + max_seq_len=max_seq_length, + generate_full_logits=False, + **params, + ) + + with torch.device("meta"): + rope = Rope(args) + in_block = InputBlock(args) + out_block = OutputBlock(args) + transformer_blocks = [] + for layer_id in range(args.n_layers): + transformer_blocks.append(AttentionBlock(layer_id, args, rope)) + transformer_blocks.append(FeedForwardBlock(layer_id, args, rope)) + + checkpoint = torch.load( + checkpoint_path, map_location="cpu", mmap=True, weights_only=True + ) + if "model" in checkpoint: + checkpoint = checkpoint["model"] + + for model in [in_block, out_block]: + missing, unexpected = model.load_state_dict( + checkpoint, + strict=False, + assign=True, + ) + assert len(missing) == 0 + + for i in range(len(transformer_blocks) // 2): + layer_prefix = f"layers.{i}." 
+ layer_checkpoint = { + k[len(layer_prefix) :]: v + for k, v in checkpoint.items() + if k.startswith(layer_prefix) + } + missing, unexpected = transformer_blocks[2 * i].load_state_dict( + layer_checkpoint, + strict=False, + assign=True, + ) + assert len(missing) == 0 + missing, unexpected = transformer_blocks[2 * i + 1].load_state_dict( + layer_checkpoint, + strict=False, + assign=True, + ) + assert len(missing) == 0 + + cache_shape = (1, args.n_kv_heads, max_seq_length - seq_length, args.head_dim) + attn_mask_shape = (seq_length, max_seq_length) + freqs_shape = (seq_length, args.head_dim // 2) + h_shape = (1, seq_length, args.dim) + + example_inputs = [ + ( + torch.zeros(1, seq_length, dtype=torch.int64), + torch.tensor([0], dtype=torch.long), + ), # InputBlock + ( # AttentionBlock + torch.zeros(h_shape, dtype=float_dtype), # h + torch.zeros(freqs_shape, dtype=float_dtype), # freqs_cos + torch.zeros(freqs_shape, dtype=float_dtype), # freqs_sin + torch.zeros(cache_shape, dtype=float_dtype), # k_cache + torch.zeros(cache_shape, dtype=float_dtype), # v_cache + torch.zeros(attn_mask_shape, dtype=float_dtype), + ), + (torch.zeros(h_shape, dtype=float_dtype),), # FeedForwardBlock # h + ( # OutputBlock + torch.zeros(h_shape, dtype=float_dtype), # h + torch.tensor([0], dtype=torch.long), # input_length + ), + ] + + models = [in_block] + transformer_blocks + [out_block] + for i in range(len(models)): + models[i] = models[i].to(float_dtype) + + return models, example_inputs + + class InputManager: class NGramCache: def __init__(self, max_size: int): diff --git a/examples/apple/coreml/llama/main.cpp b/examples/apple/coreml/llama/main.cpp new file mode 100644 index 00000000000..4dc8b310bfa --- /dev/null +++ b/examples/apple/coreml/llama/main.cpp @@ -0,0 +1,294 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "model_manager.hpp" + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +DEFINE_string( + model_path, + "", + "Path to directory containing piecewise model files (input_block.pte, transformer_block_*.pte, output_block.pte)"); + +DEFINE_string( + tokenizer_path, + "", + "Path to tokenizer file (tokenizer.model or tokenizer.bin)"); + +DEFINE_string( + prompt, + "Once upon a time,", + "Text prompt for generation"); + +DEFINE_double( + temperature, + 0.6, + "Temperature for sampling. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic"); + +DEFINE_int32( + max_new_tokens, + 100, + "Maximum number of new tokens to generate"); + +using executorch::aten::Tensor; +using executorch::extension::ModelManager; +using executorch::runtime::Error; +using tokenizers::Tokenizer; + +/** + * Samples the next token from logits using temperature sampling. 
+ */ +static int64_t sample_token(const float* logits, int64_t vocab_size, double temperature) { + if (temperature == 0.0) { + // Greedy sampling - find argmax + int64_t max_idx = 0; + float max_val = logits[0]; + for (int64_t i = 1; i < vocab_size; ++i) { + if (logits[i] > max_val) { + max_val = logits[i]; + max_idx = i; + } + } + return max_idx; + } + + // Temperature sampling + std::vector probs(vocab_size); + float max_logit = logits[0]; + for (int64_t i = 1; i < vocab_size; ++i) { + if (logits[i] > max_logit) { + max_logit = logits[i]; + } + } + + // Apply temperature and softmax + float sum = 0.0f; + for (int64_t i = 0; i < vocab_size; ++i) { + probs[i] = std::exp((logits[i] - max_logit) / temperature); + sum += probs[i]; + } + for (int64_t i = 0; i < vocab_size; ++i) { + probs[i] /= sum; + } + + // Sample from the distribution + static std::random_device rd; + static std::mt19937 gen(rd()); + std::discrete_distribution dist(probs.begin(), probs.end()); + return dist(gen); +} + +int main(int argc, char** argv) { + executorch::runtime::runtime_init(); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (FLAGS_model_path.empty()) { + ET_LOG(Error, "Must specify --model_path"); + return 1; + } + + if (FLAGS_tokenizer_path.empty()) { + ET_LOG(Error, "Must specify --tokenizer_path"); + return 1; + } + + // Load tokenizer + ET_LOG(Info, "Loading tokenizer from %s", FLAGS_tokenizer_path.c_str()); + auto tiktoken = example::get_tiktoken_for_llama(); + if (!tiktoken) { + ET_LOG(Error, "Failed to create tiktoken"); + return 1; + } + + auto load_error = tiktoken->load(FLAGS_tokenizer_path); + if (load_error != tokenizers::Error::Ok) { + ET_LOG(Error, "Failed to load tokenizer: %d", static_cast(load_error)); + return 1; + } + + // Load model + ET_LOG(Info, "Loading model from %s", FLAGS_model_path.c_str()); + std::unique_ptr model_manager; + try { + model_manager = std::make_unique(FLAGS_model_path); + } catch (const std::exception& e) { + ET_LOG(Error, "Failed to load model: %s", e.what()); + return 1; + } + + // Tokenize prompt + auto encode_res = tiktoken->encode(FLAGS_prompt, /*bos=*/1, /*eos=*/0); + if (!encode_res.ok()) { + ET_LOG(Error, "Failed to encode prompt"); + return 1; + } + std::vector prompt_tokens_u64 = std::move(*encode_res); + + // Convert to int64_t + std::vector tokens; + for (uint64_t token : prompt_tokens_u64) { + tokens.push_back(static_cast(token)); + } + + ET_LOG(Info, "Prompt has %zu tokens", tokens.size()); + std::cout << FLAGS_prompt << std::flush; + + // Get EOS token for stop check + uint64_t eos_token = tiktoken->eos_tok(); + + // Generation loop + int64_t max_seq_length = model_manager->get_max_seq_length(); + int64_t input_pos = 0; // Track position externally + + // Pre-allocate buffer for input_pos tensor + std::vector input_pos_buffer = {0}; + + // Timing statistics + bool is_prefill = true; + int64_t prefill_tokens = 0; + double prefill_time_ms = 0.0; + int64_t decode_tokens = 0; + double decode_time_ms = 0.0; + + for (int32_t i = 0; i < FLAGS_max_new_tokens && input_pos < max_seq_length; ++i) { + // Create tokens tensor [1, num_tokens] - ModelManager will handle chunking if needed + size_t num_tokens = tokens.size(); + auto tokens_tensor = executorch::extension::from_blob( + tokens.data(), + {1, static_cast(num_tokens)}, + executorch::aten::ScalarType::Long); + + // Update input_pos buffer and create tensor [1] + input_pos_buffer[0] = input_pos; + auto input_pos_tensor = executorch::extension::from_blob( + input_pos_buffer.data(), + {1}, + 
executorch::aten::ScalarType::Long); + + // Start timing + auto start_time = std::chrono::high_resolution_clock::now(); + + // Run forward pass with tensors + auto logits_result = model_manager->forward(*tokens_tensor, *input_pos_tensor); + if (!logits_result.ok()) { + ET_LOG(Error, "Forward pass failed: 0x%" PRIx32, + static_cast(logits_result.error())); + return 1; + } + + Tensor logits = std::move(*logits_result); + + // End timing + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end_time - start_time); + double time_ms = duration.count() / 1000.0; + + // Track prefill vs decode timing + if (is_prefill) { + prefill_tokens = num_tokens; + prefill_time_ms = time_ms; + is_prefill = false; // Switch to decode after first pass + } else { + decode_tokens++; + decode_time_ms += time_ms; + } + + // Sample next token from logits + // The output block returns logits shaped [batch_size, vocab_size] for the last position + int64_t vocab_size = logits.size(logits.dim() - 1); + + // Convert fp16 logits to fp32 for sampling + const exec_aten::Half* logits_fp16 = logits.const_data_ptr(); + + // Calculate total elements + int64_t total_elements = 1; + for (int32_t d = 0; d < logits.dim(); ++d) { + total_elements *= logits.size(d); + } + + std::vector logits_fp32(total_elements); + for (int64_t j = 0; j < total_elements; ++j) { + logits_fp32[j] = static_cast(logits_fp16[j]); + } + + // Calculate the total number of positions in the logits tensor + int64_t total_positions = 1; + for (int32_t d = 0; d < logits.dim() - 1; ++d) { + total_positions *= logits.size(d); + } + + // Get pointer to last position's logits + size_t last_valid_pos_offset = (total_positions - 1) * vocab_size; + const float* last_logits = logits_fp32.data() + last_valid_pos_offset; + + int64_t next_token = sample_token(last_logits, vocab_size, FLAGS_temperature); + + // Check for stop tokens + if (static_cast(next_token) == eos_token) { + ET_LOG(Info, "Reached end-of-text token"); + break; + } + + // Decode and print + auto decode_res = tiktoken->decode(next_token, next_token); + if (decode_res.ok()) { + std::cout << *decode_res << std::flush; + } + + // Update position and prepare for next iteration + input_pos += num_tokens; // Advance position by number of tokens processed + tokens = {next_token}; // Next iteration processes just the new token + } + + std::cout << std::endl; + + // Print timing statistics + std::cout << "\n=== Generation Statistics ===" << std::endl; + + if (prefill_tokens > 0) { + double prefill_tok_per_sec = (prefill_time_ms > 0) ? (prefill_tokens * 1000.0 / prefill_time_ms) : 0.0; + std::cout << "Prefill: " << prefill_tokens << " tokens in " + << prefill_time_ms << " ms (" + << prefill_tok_per_sec << " tok/s)" << std::endl; + } + + if (decode_tokens > 0) { + double avg_decode_time_ms = decode_time_ms / decode_tokens; + double decode_tok_per_sec = (decode_time_ms > 0) ? (decode_tokens * 1000.0 / decode_time_ms) : 0.0; + std::cout << "Decode: " << decode_tokens << " tokens in " + << decode_time_ms << " ms (" + << decode_tok_per_sec << " tok/s, " + << avg_decode_time_ms << " ms/tok)" << std::endl; + } + + double total_time_ms = prefill_time_ms + decode_time_ms; + int64_t total_tokens = prefill_tokens + decode_tokens; + if (total_tokens > 0) { + double total_tok_per_sec = (total_time_ms > 0) ? 
(total_tokens * 1000.0 / total_time_ms) : 0.0; + std::cout << "Total: " << total_tokens << " tokens in " + << total_time_ms << " ms (" + << total_tok_per_sec << " tok/s)" << std::endl; + } + + ET_LOG(Info, "Generation complete"); + + return 0; +} diff --git a/examples/apple/coreml/llama/model_manager.cpp b/examples/apple/coreml/llama/model_manager.cpp new file mode 100644 index 00000000000..c7124062ccd --- /dev/null +++ b/examples/apple/coreml/llama/model_manager.cpp @@ -0,0 +1,644 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "model_manager.hpp" + +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace extension { + +namespace fs = std::filesystem; + +using runtime::Error; +using runtime::EValue; +using runtime::MethodMeta; +using runtime::Result; +using aten::ScalarType; +using aten::Tensor; + +ModelManager::ModelManager(const std::string& model_path) + : n_layers_(0), + max_batch_size_(0), + n_kv_heads_(0), + cache_size_(0), + head_dim_(0), + seq_length_(0), + max_seq_length_(0) { + // Infer number of layers from the single transformer_block.pte + n_layers_ = infer_n_layers(model_path); + if (n_layers_ == 0) { + std::stringstream ss; + ss << "No transformer_block.pte file found or invalid output count in directory: " << model_path; + throw std::runtime_error(ss.str()); + } + + ET_LOG(Info, "Detected piecewise model with %lld layers", (long long)n_layers_); + + // Load all model pieces + Error error = load_input_block(model_path); + if (error != Error::Ok) { + std::stringstream ss; + ss << "Failed to load input block: " << static_cast(error); + throw std::runtime_error(ss.str()); + } + + error = load_transformer_blocks(model_path); + if (error != Error::Ok) { + std::stringstream ss; + ss << "Failed to load transformer blocks: " << static_cast(error); + throw std::runtime_error(ss.str()); + } + + error = load_output_block(model_path); + if (error != Error::Ok) { + std::stringstream ss; + ss << "Failed to load output block: " << static_cast(error); + throw std::runtime_error(ss.str()); + } + + // Extract metadata from the first transformer block + error = extract_metadata(); + if (error != Error::Ok) { + std::stringstream ss; + ss << "Failed to extract metadata: " << static_cast(error); + throw std::runtime_error(ss.str()); + } + + ET_LOG(Info, "Model loaded successfully"); + ET_LOG(Info, " max_batch_size: %lld", (long long)max_batch_size_); + ET_LOG(Info, " n_kv_heads: %lld", (long long)n_kv_heads_); + ET_LOG(Info, " cache_size: %lld", (long long)cache_size_); + ET_LOG(Info, " head_dim: %lld", (long long)head_dim_); + ET_LOG(Info, " seq_length: %lld", (long long)seq_length_); + ET_LOG(Info, " max_seq_length: %lld", (long long)max_seq_length_); + + // Allocate KV caches and attention mask + // Using fp16 (2 bytes per element) + size_t bytes_per_elem = 2; + size_t cache_elem_count = max_batch_size_ * n_kv_heads_ * cache_size_ * head_dim_; + size_t cache_byte_count = cache_elem_count * bytes_per_elem; + + k_caches_data_.resize(n_layers_); + v_caches_data_.resize(n_layers_); + for (int64_t i = 0; i < n_layers_; ++i) { + // Allocate as bytes and initialize to zero + k_caches_data_[i].resize(cache_byte_count, 0); + v_caches_data_[i].resize(cache_byte_count, 0); + } + + // Allocate attention mask + size_t mask_elem_count = seq_length_ * 
max_seq_length_; + size_t mask_byte_count = mask_elem_count * bytes_per_elem; + mask_data_.resize(mask_byte_count, 0); + + // attn_cache = minus_infinity * torch.ones(seq_length, cache_size) // attn for past tokens + // attn_seq = torch.triu(minus_infinity * torch.ones(seq_length, seq_length), diagonal=1) // attn for current tokens + // attn_mask = concat([attn_cache, attn_seq], dim=-1) + + // Use -30000.0 as minus_infinity value to prevent under/overflow in FP16 + exec_aten::Half* mask_ptr = reinterpret_cast(mask_data_.data()); + exec_aten::Half minus_inf_half(-30000.0f); + exec_aten::Half zero_half(0.0f); + + for (int64_t row = 0; row < seq_length_; ++row) { + // First cache_size columns: all minus_inf + for (int64_t col = 0; col < cache_size_; ++col) { + mask_ptr[row * max_seq_length_ + col] = minus_inf_half; + } + // Next seq_length columns: upper triangular (diagonal=1) + // diagonal=1 means diagonal and below are 0, above diagonal is minus_inf + for (int64_t col = 0; col < seq_length_; ++col) { + if (col > row) { + mask_ptr[row * max_seq_length_ + cache_size_ + col] = minus_inf_half; + } else { + mask_ptr[row * max_seq_length_ + cache_size_ + col] = zero_half; + } + } + } + + // Initialize cache position tracking + cache_pos_ = 0; + + // Pre-create tensor views for k/v caches and attention mask + // This avoids overhead of creating tensors in the forward() loop + k_caches_.reserve(n_layers_); + v_caches_.reserve(n_layers_); + + for (int64_t i = 0; i < n_layers_; ++i) { + k_caches_.push_back(from_blob( + k_caches_data_[i].data(), + {static_cast(max_batch_size_), static_cast(n_kv_heads_), + static_cast(cache_size_), static_cast(head_dim_)}, + ScalarType::Half)); + + v_caches_.push_back(from_blob( + v_caches_data_[i].data(), + {static_cast(max_batch_size_), static_cast(n_kv_heads_), + static_cast(cache_size_), static_cast(head_dim_)}, + ScalarType::Half)); + } + + attn_mask_ = from_blob( + mask_data_.data(), + {static_cast(seq_length_), static_cast(max_seq_length_)}, + ScalarType::Half); + + // Pre-allocate buffers for forward() chunking/padding + chunk_buffer_.resize(seq_length_, 0); + pos_buffer_.resize(1, 0); + input_length_buffer_.resize(1, 0); + + // Pre-allocate EValue vector for layer execution (6 inputs per layer) + layer_inputs_.reserve(6); + + ET_LOG(Info, "Allocated KV caches and attention mask"); +} + +int64_t ModelManager::infer_n_layers(const std::string& model_path) { + // Check for the single transformer_block.pte file + std::string single_block_path = model_path + "/transformer_block.pte"; + if (!fs::exists(single_block_path)) { + ET_LOG(Error, "transformer_block.pte not found in directory: %s", model_path.c_str()); + return 0; + } + + // Load temporarily to get metadata + auto temp_module = std::make_unique(single_block_path); + Error error = temp_module->load(); + if (error != Error::Ok) { + ET_LOG(Error, "Failed to load transformer block: 0x%" PRIx32, static_cast(error)); + return 0; + } + + error = temp_module->load_method("forward"); + if (error != Error::Ok) { + ET_LOG(Error, "Failed to load forward method: 0x%" PRIx32, static_cast(error)); + return 0; + } + + // Get method metadata to determine output count + auto meta_result = temp_module->method_meta("forward"); + if (!meta_result.ok()) { + ET_LOG(Error, "Failed to get method metadata: 0x%" PRIx32, static_cast(meta_result.error())); + return 0; + } + + const MethodMeta& metadata = *meta_result; + size_t num_outputs = metadata.num_outputs(); + + // Output format: (h, k_cache_0, ..., k_cache_n, v_cache_0, ..., 
v_cache_n) + // So: num_outputs = 1 + n_layers + n_layers = 1 + 2*n_layers + // Therefore: n_layers = (num_outputs - 1) / 2 + if ((num_outputs - 1) % 2 != 0) { + ET_LOG(Error, "Invalid output count %zu for transformer block", num_outputs); + return 0; + } + + int64_t n_layers = (num_outputs - 1) / 2; + ET_LOG(Info, "Found transformer_block.pte with %lld layers", (long long)n_layers); + return n_layers; +} + +Error ModelManager::load_input_block(const std::string& model_path) { + std::string input_block_path = model_path + "/input_block.pte"; + + if (!fs::exists(input_block_path)) { + ET_LOG(Error, "Input block not found: %s", input_block_path.c_str()); + return Error::InvalidArgument; + } + + input_proj_module_ = std::make_unique(input_block_path); + Error error = input_proj_module_->load(); + if (error != Error::Ok) { + ET_LOG(Error, "Failed to load input block: 0x%" PRIx32, static_cast(error)); + return error; + } + + error = input_proj_module_->load_method("forward"); + if (error != Error::Ok) { + ET_LOG(Error, "Failed to load forward method for input block: 0x%" PRIx32, static_cast(error)); + return error; + } + + ET_LOG(Info, "Loaded input block"); + return Error::Ok; +} + +Error ModelManager::load_transformer_blocks(const std::string& model_path) { + // Load single transformer_block.pte file + std::string block_path = model_path + "/transformer_block.pte"; + + if (!fs::exists(block_path)) { + ET_LOG(Error, "Transformer block not found: %s", block_path.c_str()); + return Error::InvalidArgument; + } + + auto module = std::make_unique(block_path); + Error error = module->load(); + if (error != Error::Ok) { + ET_LOG(Error, "Failed to load transformer block: 0x%" PRIx32, static_cast(error)); + return error; + } + + error = module->load_method("forward"); + if (error != Error::Ok) { + ET_LOG(Error, "Failed to load forward method for transformer block: 0x%" PRIx32, static_cast(error)); + return error; + } + + transformer_modules_.push_back(std::move(module)); + ET_LOG(Info, "Loaded transformer block"); + + return Error::Ok; +} + +Error ModelManager::load_output_block(const std::string& model_path) { + std::string output_block_path = model_path + "/output_block.pte"; + + if (!fs::exists(output_block_path)) { + ET_LOG(Error, "Output block not found: %s", output_block_path.c_str()); + return Error::InvalidArgument; + } + + output_proj_module_ = std::make_unique(output_block_path); + Error error = output_proj_module_->load(); + if (error != Error::Ok) { + ET_LOG(Error, "Failed to load output block: 0x%" PRIx32, static_cast(error)); + return error; + } + + error = output_proj_module_->load_method("forward"); + if (error != Error::Ok) { + ET_LOG(Error, "Failed to load forward method for output block: 0x%" PRIx32, static_cast(error)); + return error; + } + + ET_LOG(Info, "Loaded output block"); + return Error::Ok; +} + +Error ModelManager::extract_metadata() { + if (transformer_modules_.empty()) { + ET_LOG(Error, "No transformer modules loaded"); + return Error::InvalidState; + } + + // Get metadata from the transformer block + auto meta_result = transformer_modules_[0]->method_meta("forward"); + if (!meta_result.ok()) { + ET_LOG(Error, "Failed to get method metadata: 0x%" PRIx32, static_cast(meta_result.error())); + return meta_result.error(); + } + + const MethodMeta& metadata = *meta_result; + + ET_LOG(Info, "Method metadata (transformer block):"); + ET_LOG(Info, " Number of inputs: %zu", metadata.num_inputs()); + + // Single transformer block with list interface: 3 + n_layers + n_layers + 1 = 
4 + 2*n_layers + size_t expected_inputs = 4 + 2 * n_layers_; + + if (metadata.num_inputs() != expected_inputs) { + ET_LOG(Error, "Expected %zu inputs for transformer block (4 + 2*%lld layers), got %zu", + expected_inputs, (long long)n_layers_, metadata.num_inputs()); + return Error::InvalidArgument; + } + + // Extract k_cache input metadata (first k_cache is at index 3) + size_t k_cache_index = 3; + auto k_cache_meta_result = metadata.input_tensor_meta(k_cache_index); + if (!k_cache_meta_result.ok()) { + ET_LOG(Error, "Failed to get k_cache tensor metadata"); + return k_cache_meta_result.error(); + } + + const auto& k_cache_meta = *k_cache_meta_result; + auto k_cache_sizes = k_cache_meta.sizes(); + + if (k_cache_sizes.size() != 4) { + ET_LOG(Error, "Expected 4 dimensions for k_cache, got %zu", k_cache_sizes.size()); + return Error::InvalidArgument; + } + + max_batch_size_ = k_cache_sizes[0]; + n_kv_heads_ = k_cache_sizes[1]; + cache_size_ = k_cache_sizes[2]; + head_dim_ = k_cache_sizes[3]; + + // Assert that the model uses fp16 + ScalarType model_dtype = k_cache_meta.scalar_type(); + if (model_dtype != ScalarType::Half) { + ET_LOG(Error, "Model must use fp16 (Half) dtype, got dtype: %d", static_cast(model_dtype)); + return Error::InvalidArgument; + } + + ET_LOG(Info, " k_cache shape: [%lld, %lld, %lld, %lld]", + (long long)max_batch_size_, + (long long)n_kv_heads_, + (long long)cache_size_, + (long long)head_dim_); + ET_LOG(Info, " dtype: fp16"); + + // Extract mask input metadata (mask is at index 3 + 2*n_layers) + size_t mask_index = 3 + 2 * n_layers_; + auto mask_meta_result = metadata.input_tensor_meta(mask_index); + if (!mask_meta_result.ok()) { + ET_LOG(Error, "Failed to get mask tensor metadata at index %zu", mask_index); + return mask_meta_result.error(); + } + + const auto& mask_meta = *mask_meta_result; + auto mask_sizes = mask_meta.sizes(); + + if (mask_sizes.size() != 2) { + ET_LOG(Error, "Expected 2 dimensions for mask, got %zu", mask_sizes.size()); + return Error::InvalidArgument; + } + + seq_length_ = mask_sizes[0]; + max_seq_length_ = mask_sizes[1]; + + ET_LOG(Info, " mask shape: [%lld, %lld]", + (long long)seq_length_, + (long long)max_seq_length_); + + return Error::Ok; +} + +void ModelManager::reset_caches() { + // Reset KV caches to zero + for (int64_t i = 0; i < n_layers_; ++i) { + std::fill(k_caches_data_[i].begin(), k_caches_data_[i].end(), 0.0f); + std::fill(v_caches_data_[i].begin(), v_caches_data_[i].end(), 0.0f); + } + ET_LOG(Info, "Reset KV caches"); +} + +Result ModelManager::forward( + const Tensor& tokens, + const Tensor& input_pos) { + // Validate inputs + if (tokens.dim() != 2) { + ET_LOG(Error, "tokens tensor must be 2D, got %zd dimensions", tokens.dim()); + return Error::InvalidArgument; + } + + if (input_pos.dim() != 1 || input_pos.size(0) != 1) { + ET_LOG(Error, "input_pos tensor must be shape [1], got dim=%zd size=%lld", + input_pos.dim(), (long long)(input_pos.dim() > 0 ? 
input_pos.size(0) : 0)); + return Error::InvalidArgument; + } + + // Get input_pos value + int64_t input_pos_value = input_pos.const_data_ptr()[0]; + + // For inference, we assume all tokens in the input are valid + // (the caller is responsible for proper padding) + const int64_t* tokens_data = tokens.const_data_ptr(); + int64_t token_dim_size = tokens.size(1); // Second dimension of [batch, num_tokens] + int64_t num_tokens = token_dim_size; // Use full dimension as num_tokens + + // Case 1: num_tokens <= seq_length - pad and process once + if (num_tokens <= seq_length_) { + // Fill chunk_buffer_ with tokens, rest is already zeros (padded) + std::fill(chunk_buffer_.begin(), chunk_buffer_.end(), 0); + std::copy(tokens_data, tokens_data + num_tokens, chunk_buffer_.begin()); + + auto chunk_tensor = from_blob( + chunk_buffer_.data(), + {1, static_cast(seq_length_)}, + ScalarType::Long); + + pos_buffer_[0] = input_pos_value; + auto pos_tensor = from_blob( + pos_buffer_.data(), + {1}, + ScalarType::Long); + + // Pass actual num_tokens to forward_() so it can pass it to output_block + return forward_(*chunk_tensor, *pos_tensor, num_tokens); + } + + // Case 2: num_tokens > seq_length - process in chunks + std::optional last_logits; + int64_t current_pos = input_pos_value; + + for (int64_t offset = 0; offset < num_tokens; offset += seq_length_) { + int64_t chunk_size = std::min(seq_length_, num_tokens - offset); + + // Clear buffer and fill with chunk + std::fill(chunk_buffer_.begin(), chunk_buffer_.end(), 0); + std::copy(tokens_data + offset, tokens_data + offset + chunk_size, chunk_buffer_.begin()); + + auto chunk_tensor = from_blob( + chunk_buffer_.data(), + {1, static_cast(seq_length_)}, + ScalarType::Long); + + pos_buffer_[0] = current_pos; + auto pos_tensor = from_blob( + pos_buffer_.data(), + {1}, + ScalarType::Long); + + // Pass actual chunk_size to forward_() so it can pass it to output_block + auto result = forward_(*chunk_tensor, *pos_tensor, chunk_size); + if (!result.ok()) { + return result; + } + + last_logits = std::move(*result); + current_pos += chunk_size; + } + + return std::move(*last_logits); +} + +Result ModelManager::forward_( + const Tensor& tokens, + const Tensor& input_pos, + int64_t num_tokens) { + // Validate that tokens is exactly seq_length + if (tokens.dim() != 2 || tokens.size(1) != seq_length_) { + ET_LOG(Error, "forward_() requires tokens shape [1, %lld], got [%lld, %lld]", + (long long)seq_length_, + (long long)tokens.size(0), + (long long)tokens.size(1)); + return Error::InvalidArgument; + } + + // Extract input_pos value + if (input_pos.dim() != 1 || input_pos.size(0) != 1) { + ET_LOG(Error, "input_pos tensor must be shape [1], got dim=%zd size=%lld", + input_pos.dim(), (long long)(input_pos.dim() > 0 ? 
input_pos.size(0) : 0)); + return Error::InvalidArgument; + } + int64_t input_pos_value = input_pos.const_data_ptr()[0]; + + // Use the num_tokens parameter passed from forward() + // This is the actual number of valid (non-padding) tokens in the input + + if (input_pos_value + num_tokens > max_seq_length_) { + ET_LOG(Error, "Position (%lld) + num_tokens (%lld) exceeds max_seq_length (%lld)", + (long long)input_pos_value, (long long)num_tokens, (long long)max_seq_length_); + return Error::InvalidArgument; + } + + // Create input_length tensor using pre-allocated buffer + input_length_buffer_[0] = num_tokens; + auto input_length_tensor = from_blob( + input_length_buffer_.data(), + {1}, + ScalarType::Long); + + // Run input block: (tokens, input_pos) -> (h, freqs_cos, freqs_sin) + auto input_result = input_proj_module_->execute( + "forward", + {EValue(tokens), EValue(input_pos)}); + + if (!input_result.ok()) { + ET_LOG(Error, "Input block execution failed: 0x%" PRIx32, + static_cast(input_result.error())); + return input_result.error(); + } + + std::vector input_outputs = std::move(*input_result); + if (input_outputs.size() != 3) { + ET_LOG(Error, "Expected 3 outputs from input block, got %zu", input_outputs.size()); + return Error::InvalidState; + } + + Tensor h = input_outputs[0].toTensor(); + Tensor freqs_cos = input_outputs[1].toTensor(); + Tensor freqs_sin = input_outputs[2].toTensor(); + + int64_t amount_to_copy = std::min(num_tokens, cache_size_ - cache_pos_); + + update_mask(input_pos_value, amount_to_copy); + + // Single transformer block mode - one module processes all layers at once + layer_inputs_.clear(); + layer_inputs_.emplace_back(h); + layer_inputs_.emplace_back(freqs_cos); + layer_inputs_.emplace_back(freqs_sin); + + // Pass all k_caches and v_caches as separate EValues + for (int64_t i = 0; i < n_layers_; ++i) { + layer_inputs_.emplace_back(k_caches_[i]); + } + for (int64_t i = 0; i < n_layers_; ++i) { + layer_inputs_.emplace_back(v_caches_[i]); + } + layer_inputs_.emplace_back(attn_mask_); + + auto result = transformer_modules_[0]->execute("forward", layer_inputs_); + if (!result.ok()) { + ET_LOG(Error, "Single transformer block execution failed: 0x%" PRIx32, + static_cast(result.error())); + return result.error(); + } + + std::vector outputs = std::move(*result); + size_t expected_outputs = 1 + 2 * n_layers_; + if (outputs.size() != expected_outputs) { + ET_LOG(Error, "Expected %zu outputs from single transformer block (1 + 2*%lld layers), got %zu", + expected_outputs, (long long)n_layers_, outputs.size()); + return Error::InvalidState; + } + + h = outputs[0].toTensor(); + + // Update caches for all layers + for (int64_t i = 0; i < n_layers_; ++i) { + Tensor new_k = outputs[1 + i].toTensor(); + Tensor new_v = outputs[1 + n_layers_ + i].toTensor(); + update_cache(i, amount_to_copy, new_k, new_v); + } + + // Update cache position after all layers + cache_pos_ += amount_to_copy; + if (cache_pos_ >= cache_size_) { + cache_pos_ = 0; + } + + // Run output block: (h, input_length) -> (logits,) + auto output_result = output_proj_module_->execute( + "forward", + {EValue(h), EValue(*input_length_tensor)}); + + if (!output_result.ok()) { + ET_LOG(Error, "Output block execution failed: 0x%" PRIx32, + static_cast(output_result.error())); + return output_result.error(); + } + + std::vector output_outputs = std::move(*output_result); + if (output_outputs.size() != 1) { + ET_LOG(Error, "Expected 1 output from output block, got %zu", output_outputs.size()); + return Error::InvalidState; 
+ } + + return output_outputs[0].toTensor(); +} + +void ModelManager::update_cache( + int64_t layer_id, + int64_t amount_to_copy, + const Tensor& new_k_cache, + const Tensor& new_v_cache) { + // Using fp16 (2 bytes per element) + size_t elem_size = 2; + size_t row_bytes = head_dim_ * elem_size; + + const char* new_k_bytes = reinterpret_cast(new_k_cache.const_data_ptr()); + const char* new_v_bytes = reinterpret_cast(new_v_cache.const_data_ptr()); + char* k_cache_bytes = reinterpret_cast(k_caches_data_[layer_id].data()); + char* v_cache_bytes = reinterpret_cast(v_caches_data_[layer_id].data()); + + // Copy new_cache[:, :, 0:amount_to_copy, :] -> cache[:, :, cache_pos_:cache_pos_+amount_to_copy, :] + // Optimization: Copy all positions for each (batch, head) in a single memcpy + // since positions are contiguous in memory + size_t copy_bytes = amount_to_copy * row_bytes; + + for (int64_t batch = 0; batch < max_batch_size_; ++batch) { + for (int64_t head = 0; head < n_kv_heads_; ++head) { + // Source: new_cache[batch, head, 0:amount_to_copy, :] + size_t src_offset = ((batch * n_kv_heads_ + head) * seq_length_) * row_bytes; + // Dest: persistent_cache[batch, head, cache_pos_:cache_pos_+amount_to_copy, :] + size_t dst_offset = ((batch * n_kv_heads_ + head) * cache_size_ + cache_pos_) * row_bytes; + + std::memcpy(k_cache_bytes + dst_offset, new_k_bytes + src_offset, copy_bytes); + std::memcpy(v_cache_bytes + dst_offset, new_v_bytes + src_offset, copy_bytes); + } + } +} + +void ModelManager::update_mask(int64_t input_pos, int64_t amount_to_copy) { + if (input_pos <= cache_size_) { + char* mask_bytes = reinterpret_cast(mask_data_.data()); + size_t elem_size = 2; // fp16 uses 2 bytes + + int64_t num_cols_to_zero = std::min(amount_to_copy, max_seq_length_ - input_pos); + for (int64_t row = 0; row < seq_length_; ++row) { + size_t offset = (row * max_seq_length_ + input_pos) * elem_size; + size_t count = num_cols_to_zero * elem_size; + std::memset(mask_bytes + offset, 0, count); + } + } +} + +} // namespace extension +} // namespace executorch diff --git a/examples/apple/coreml/llama/model_manager.hpp b/examples/apple/coreml/llama/model_manager.hpp new file mode 100644 index 00000000000..f237d19d10d --- /dev/null +++ b/examples/apple/coreml/llama/model_manager.hpp @@ -0,0 +1,218 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace extension { + +/** + * Manages loading and execution of a piecewise LLaMA model consisting of: + * - input_block.pte + * - transformer_block_0.pte, transformer_block_1.pte, ..., transformer_block_{n_layers-1}.pte + * - output_block.pte + * + * The C++ version automatically infers n_layers from the number of + * transformer_block_*.pte files present in the model directory. + * + * The ModelManager internally manages KV caches and provides a simple + * forward() API similar to CPU models. + */ +class ModelManager { + public: + /** + * Constructs a ModelManager by loading all model pieces from the + * specified directory. 
+ * + * @param model_path Path to directory containing the piecewise model files + * (input_block.pte, transformer_block.pte, output_block.pte) + * + * The constructor automatically: + * - Loads the single transformer_block.pte containing all layers + * - Infers n_layers from the transformer block's output count + * - Extracts metadata to determine: max_batch_size, n_kv_heads, cache_size, + * head_dim, seq_length, max_seq_length, and float_dtype + * - Allocates and initializes KV caches + */ + explicit ModelManager(const std::string& model_path); + + // Prevent copying and moving + ModelManager(const ModelManager&) = delete; + ModelManager& operator=(const ModelManager&) = delete; + ModelManager(ModelManager&&) = delete; + ModelManager& operator=(ModelManager&&) = delete; + + ~ModelManager() = default; + + /** + * Runs inference on the model with the given tokens at the specified position. + * + * This method handles tokens of any size by: + * - Padding to seq_length if tokens < seq_length + * - Processing in seq_length chunks if tokens > seq_length + * + * @param tokens Input token IDs tensor, shape [1, num_tokens] + * @param input_pos Position tensor, shape [1] indicating where these tokens start + * + * @return Result containing the logits tensor (shape: [batch_size, vocab_size]) + */ + runtime::Result forward( + const executorch::aten::Tensor& tokens, + const executorch::aten::Tensor& input_pos); + + /** + * Resets the KV caches to zero. Useful for starting a new sequence. + */ + void reset_caches(); + + // Accessors for model metadata + int64_t get_n_layers() const { return n_layers_; } + int64_t get_max_batch_size() const { return max_batch_size_; } + int64_t get_n_kv_heads() const { return n_kv_heads_; } + int64_t get_cache_size() const { return cache_size_; } + int64_t get_head_dim() const { return head_dim_; } + int64_t get_seq_length() const { return seq_length_; } + int64_t get_max_seq_length() const { return max_seq_length_; } + + private: + /** + * Infers the number of layers by counting transformer_block_*.pte files + * in the model directory. + * + * @param model_path Path to the model directory + * @return Number of transformer blocks found + */ + int64_t infer_n_layers(const std::string& model_path); + + /** + * Loads the input projection block. + * + * @param model_path Path to the model directory + * @return Error indicating success or failure + */ + runtime::Error load_input_block(const std::string& model_path); + + /** + * Loads all transformer blocks. + * + * @param model_path Path to the model directory + * @return Error indicating success or failure + */ + runtime::Error load_transformer_blocks(const std::string& model_path); + + /** + * Loads the output projection block. + * + * @param model_path Path to the model directory + * @return Error indicating success or failure + */ + runtime::Error load_output_block(const std::string& model_path); + + /** + * Extracts metadata from the first transformer block. + * + * @return Error indicating success or failure + */ + runtime::Error extract_metadata(); + + /** + * Updates a single layer's KV caches with new cache values. + * This matches the Python InputManager._update_cache() logic. 
+   *
+   * @param layer_id Which layer to update
+   * @param amount_to_copy Number of tokens to copy (pre-calculated)
+   * @param new_k_cache New K cache tensor for this layer
+   * @param new_v_cache New V cache tensor for this layer
+   */
+  void update_cache(
+      int64_t layer_id,
+      int64_t amount_to_copy,
+      const executorch::aten::Tensor& new_k_cache,
+      const executorch::aten::Tensor& new_v_cache);
+
+  /**
+   * Updates the attention mask for the current input position.
+   * This matches the Python InputManager.update() mask update logic.
+   *
+   * @param input_pos Current input position
+   * @param amount_to_copy Number of tokens to update in the mask
+   */
+  void update_mask(int64_t input_pos, int64_t amount_to_copy);
+
+  /**
+   * Private core inference method that requires exactly seq_length tokens.
+   * This is the actual forward pass implementation; the public forward() method
+   * handles chunking/padding and calls this.
+   *
+   * @param tokens Input token IDs tensor, shape [1, seq_length] (must be exact)
+   * @param input_pos Position tensor, shape [1]
+   * @param num_tokens Actual number of valid (non-padding) tokens in the input
+   *
+   * @return Result containing the logits tensor (shape: [batch_size, vocab_size])
+   */
+  runtime::Result<executorch::aten::Tensor> forward_(
+      const executorch::aten::Tensor& tokens,
+      const executorch::aten::Tensor& input_pos,
+      int64_t num_tokens);
+
+  // Model pieces
+  std::unique_ptr<Module> input_proj_module_;
+  std::unique_ptr<Module> output_proj_module_;
+  std::vector<std::unique_ptr<Module>> transformer_modules_;
+
+  // Model metadata extracted from the transformer block
+  int64_t n_layers_;
+  int64_t max_batch_size_;
+  int64_t n_kv_heads_;
+  int64_t cache_size_;
+  int64_t head_dim_;
+  int64_t seq_length_;
+  int64_t max_seq_length_;
+
+  // KV cache storage (managed internally)
+  // KV cache data storage (one per layer): [max_batch_size, n_kv_heads, cache_size, head_dim]
+  // Using uint8_t to store raw bytes (supports both FP16 and FP32)
+  std::vector<std::vector<uint8_t>> k_caches_data_;
+  std::vector<std::vector<uint8_t>> v_caches_data_;
+
+  // KV cache tensor views (pre-created for each layer to avoid overhead in forward())
+  std::vector<std::shared_ptr<executorch::aten::Tensor>> k_caches_;
+  std::vector<std::shared_ptr<executorch::aten::Tensor>> v_caches_;
+
+  // Attention mask data storage: [seq_length, max_seq_length]
+  // Using uint8_t to store raw bytes (supports both FP16 and FP32)
+  std::vector<uint8_t> mask_data_;
+
+  // Attention mask tensor view (pre-created to avoid overhead in forward())
+  std::shared_ptr<executorch::aten::Tensor> attn_mask_;
+
+  // Cache position tracking (managed internally)
+  int64_t cache_pos_;
+
+  // Pre-allocated buffers for forward() chunking/padding (to avoid allocation overhead)
+  std::vector<int64_t> chunk_buffer_; // Size: seq_length
+  std::vector<int64_t> pos_buffer_; // Size: 1
+  std::vector<int64_t> input_length_buffer_; // Size: 1
+
+  // Pre-allocated EValue vector for layer execution (to avoid repeated allocations)
+  std::vector<runtime::EValue> layer_inputs_;
+};
+
+} // namespace extension
+} // namespace executorch
diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md
index 14dff0c8580..6b5622e05c3 100644
--- a/examples/apple/coreml/llama/readme.md
+++ b/examples/apple/coreml/llama/readme.md
@@ -2,45 +2,42 @@
 This directory contains ANE-friendly Llama models.
-Export model with:
+You can export Llama1B to run on the ANE with:
 ```
-python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w --dtype fp16
+python export.py -n /path/to/output/model_dir -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --dtype fp16 --coreml-quantize c4w -E "8,0" --target_split_size 1024 --max_splits 32 --export_in_parts
 ```
-(Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
+This exports Llama1B in 3 pieces:
+* model_dir/input_block.pte (embedding/freq calculations)
+* model_dir/transformer_block.pte (transformer layers)
+* model_dir/output_block.pte (lm_head)
 
-The runner is written in python and is only intended to serve as an example for how the model inputs should be processed; it is not performant.
+The model is exported in fp16 and quantized with 4-bit channelwise linear layers and 8-bit embeddings. These quantization settings are ANE-friendly, but they do require some QAT to get good accuracy. To run the model, use:
-
-Run model with:
 ```
-python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time,"
+buck run :llama_mainAppleMac -- \
+  --model_path /path/to/output/model_dir \
+  --tokenizer_path /path/to/tokenizer.model \
+  --prompt "Once upon a time," \
+  --temperature 0.6 \
+  --max_new_tokens 100
 ```
 
-The runner can also be used to run an eager model model to compare with CoreML numerics (--use_eager). In this case, you must specify:
-* --checkpoint
-* --dtype
-* --max_seq_length
-* --seq_length
+The model has a static input shape and processes 64 tokens per forward pass; shorter inputs are padded and longer inputs are chunked (this is controlled by --seq_length in the export args). On iOS 26 / iPhone 15 Pro, the prediction time for 64 tokens is 0.03255s, so the approximate performance is:
+* Decode: 1 token / 0.03255s = 30 tok/sec.
+* Prefill: 64 tokens / 0.03255s = 1966 tok/sec.
 
-(Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
+Note that ANE performance on an M1 Mac Pro is much slower than on an iPhone 15 Pro, so expect lower throughput if you benchmark on desktop.
 
+### Exporting in one piece
 
-## Export args
-* seq_length: the number of tokens processed by the model. Sequences shorter than seq_length must be padded, and sequences longer than it must be chunked.
-* max_seq_length: the maximum context tokens that can be processed.
-* cache_size: the size of the KV cache sequences. This parameter is optional, and defaults to max_seq_length - seq_length. If a smaller cache_size is used, older tokens are evicted from the cache and no longer play a role in attention. For example, if max_seq_length=1024, but cache_size is 512, the model can generate up to 1024 tokens, but only the current tokens and the previous 512 will participate in attention. In terms of computation, cache_size plays a similar role to max_seq_length in models without cache eviction.
-* use_cache_list: boolean option that controls whether KV caches are passed as a list of 4D tensors, one per layer, or if they are passed as one 5D tensor. (Note that use_cache_list does not work with ExecuTorch pybindings.)
-* target_split_size: this option splits linear layers into chunks of target size.
For example, if target_split_size is 1024, a linear layer with (in_features=512, out_features=8096) will be split into 8 linear layers with (in_features=512, out_features=1024) and the results concatted. If not specified, the default is no splitting.
-* max_splits: this controls the maximum number of splits for linear layers. It is only relevant if target_size is passed and defaults to 8.
-
-## Llama1B on iPhone 15
+We can also export the model in one piece by omitting the --export_in_parts flag:
+```
+python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --dtype fp16 --coreml-quantize c4w -E "8,0" --target_split_size 1024 --max_splits 8
+```
 
-We are actively experimenting with different settings. But here are ones that we've found work well for Llama1B on iPhone 15 Pro:
+We have observed that this leads to significantly slower performance on iPhone. We do not have a C++ runner for a model exported in one piece, but you can test it in Python with:
 
-* Set use_cache_list.
-* Use seq_length = 32, which offers a good balance between prefill/decode performance.
-* Split out_features in linear layers with target_split_size=1024, max_splits=8.
-* For ANE, set dtype = fp16, coreml-quantize = c4w. The requires doing QAT on Llama1B for good accuracy.
-* Set embedding-quantize to "4,32".
-* Set max_seq_length to 128, 256, 512, 1024, and 2048, depending on needed context. Note that performance drops with max_seq_length. More specifically, performance drops with cache_size, and the best experience may require a good cache eviction policy. The python runner in run.py uses a last-in-last-out policy when cache_size is specified.
+```
+python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time,"
+```
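+
+### Using ModelManager from C++
+
+The buck command above runs main.cpp, which drives the piecewise model through the ModelManager class. If you want to embed the model in your own code, the sketch below shows roughly how ModelManager can be called. It is illustrative only: the prompt token IDs are placeholders, the int64 ("Long") token dtype and the include path are assumptions about your build, and tokenization/sampling are left out.
+
+```
+#include <cstdint>
+#include <iostream>
+#include <vector>
+
+#include <executorch/extension/tensor/tensor.h>
+
+#include "model_manager.hpp" // adjust to your build's include path
+
+using executorch::extension::ModelManager;
+using executorch::extension::from_blob;
+
+int main() {
+  // Directory produced by export.py --export_in_parts.
+  ModelManager manager("/path/to/output/model_dir");
+  manager.reset_caches();
+
+  // Placeholder prompt token IDs; a real app would get these from the tokenizer.
+  std::vector<int64_t> prompt = {1, 9038, 2501, 263, 931, 29892};
+  std::vector<int64_t> pos = {0};
+
+  auto tokens_tensor = from_blob(
+      prompt.data(),
+      {1, static_cast<executorch::aten::SizesType>(prompt.size())},
+      executorch::aten::ScalarType::Long);
+  auto pos_tensor = from_blob(pos.data(), {1}, executorch::aten::ScalarType::Long);
+
+  // forward() pads or chunks the input to the model's static seq_length internally.
+  auto result = manager.forward(*tokens_tensor, *pos_tensor);
+  if (!result.ok()) {
+    std::cerr << "forward() failed" << std::endl;
+    return 1;
+  }
+
+  // result.get() holds the logits; sample the next token from them and call
+  // forward() again with input_pos advanced by the number of tokens consumed.
+  std::cout << "logits numel: " << result.get().numel() << std::endl;
+  return 0;
+}
+```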