HaoxingChen committed on
Commit
c5e3702
·
verified ·
1 Parent(s): 25b364e

Update modeling_llada2uni_moe.py

Browse files
Files changed (1) hide show
  1. modeling_llada2uni_moe.py +7 -4
modeling_llada2uni_moe.py CHANGED
@@ -2419,9 +2419,12 @@ class LLaDA2MoeModelLM(LLaDA2MoePreTrainedModel, GenerationMixin):
2419
  tok = self._get_tokenizer(tokenizer)
2420
  sp = self._get_special_tokens(tok, image_h, image_w)
2421
 
2422
- img_header = self._build_image_header(sp)
2423
- pfx = tok(question).input_ids if question else []
2424
- ids = img_header + image_tokens + sp["eoi"] + pfx
 
 
 
2425
 
2426
  out = self.generate_bd(
2427
  data={"input_ids": torch.tensor(ids).unsqueeze(0).to(self.device)},
@@ -2433,7 +2436,7 @@ class LLaDA2MoeModelLM(LLaDA2MoePreTrainedModel, GenerationMixin):
2433
  image_keep_ratio=image_keep_ratio, text_keep_ratio=text_keep_ratio,
2434
  show_progress=False,
2435
  )
2436
- return tok.decode(out[0][len(ids) - len(pfx):], skip_special_tokens=True)
2437
 
2438
  @torch.no_grad()
2439
  def edit_image(self, image_tokens, image_h, image_w, instruction,
 
2419
  tok = self._get_tokenizer(tokenizer)
2420
  sp = self._get_special_tokens(tok, image_h, image_w)
2421
 
2422
+ user = self._build_image_header(sp) + image_tokens + sp["eoi"] \
2423
+ + tok("\n").input_ids + (tok(question).input_ids if question else [])
2424
+ sys_ids, user_ids, asst_ids = self._build_chat(
2425
+ tok, "You are a multimodal understanding assistant.", user,
2426
+ )
2427
+ ids = sys_ids + user_ids + asst_ids
2428
 
2429
  out = self.generate_bd(
2430
  data={"input_ids": torch.tensor(ids).unsqueeze(0).to(self.device)},
 
2436
  image_keep_ratio=image_keep_ratio, text_keep_ratio=text_keep_ratio,
2437
  show_progress=False,
2438
  )
2439
+ return tok.decode(out[0][len(ids):], skip_special_tokens=True)
2440
 
2441
  @torch.no_grad()
2442
  def edit_image(self, image_tokens, image_h, image_w, instruction,