@@ -1955,6 +1955,17 @@ void llama_context::opt_epoch_iter(
             // }
             llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
 
+            n_outputs = ubatch.n_tokens;
+
+            printf("ubatch.n_tokens = %d\n", ubatch.n_tokens);
+
+            // TODO: not sure if this is needed
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+                GGML_ABORT("TODO: handle this error");
+            }
+
             auto * gf = graph_init();
             auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
 
@@ -1970,7 +1981,7 @@ void llama_context::opt_epoch_iter(
                 };
                 ctx_compute_opt = ggml_init(params);
             }
-            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), ggml_graph_node(gf, -1));
+            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
             ggml_opt_alloc(opt_ctx, train);
             // llama_set_inputs(*lctx, ubatch);
             res->set_inputs(&ubatch);
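
Two notes on the change above, for context. The first hunk makes every token in the ubatch produce an output (`n_outputs = ubatch.n_tokens`, since training computes a loss over all positions) and reserves a KV cache slot before building the graph, aborting for now when no slot is found; the `printf` is presumably temporary debug output. The second hunk changes which tensor is registered as the optimizer's outputs: `ggml_graph_node(gf, -1)` returns whichever node happens to be last in the graph and matches the logits only by construction order, whereas `res->get_logits()` names the logits tensor explicitly. Below is a minimal sketch of the per-ubatch training step this code follows; `build_graph`, `tokens`, `logits`, `opt_ctx`, `result`, and `ctx_compute` are hypothetical placeholders, and only the `ggml_opt_*` calls (which appear in the diff itself) are real API.

```cpp
// Minimal sketch of one ggml_opt training step, mirroring the sequence in the
// diff above. Placeholders (hypothetical): build_graph, tokens, logits,
// opt_ctx (ggml_opt_context_t), result (ggml_opt_result_t), ctx_compute.
struct ggml_tensor * tokens = nullptr;  // input token ids, set by graph build
struct ggml_tensor * logits = nullptr;  // output logits,   set by graph build
struct ggml_cgraph * gf = build_graph(ctx_compute, &tokens, &logits); // hypothetical

// Register explicit input/output tensors with the optimizer; passing the
// logits tensor directly avoids assuming it is the last node of the graph.
ggml_opt_prepare_alloc(opt_ctx, ctx_compute, gf, tokens, logits);

// Allocate buffers for the forward pass, plus the backward pass when training.
ggml_opt_alloc(opt_ctx, /*backward =*/ train);

// ... copy this ubatch's token ids and labels into the input tensors ...

// Run forward (and backward + parameter update when training); loss and
// metrics are accumulated into `result`.
ggml_opt_eval(opt_ctx, result);
```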