Skip to content

Commit 1ffa1cc

Browse files
committed
add initial lora support
1 parent 662610b commit 1ffa1cc

File tree

1 file changed

+62
-2
lines changed

1 file changed

+62
-2
lines changed

examples/server/client.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,37 @@
2222

2323
apis = ['sdapi', 'openai', 'openai-module']
2424

25+
def extract_lora_tags(prompt):
26+
27+
pattern = r'<lora:([^:>]+):([^>]+)>'
28+
lora_data = []
29+
30+
matches = list(re.finditer(pattern, prompt))
31+
32+
for match in matches:
33+
raw_path = match.group(1)
34+
raw_mul = match.group(2)
35+
try:
36+
mul = float(raw_mul)
37+
except ValueError:
38+
continue
39+
40+
is_high_noise = False
41+
prefix = "|high_noise|"
42+
if raw_path.startswith(prefix):
43+
raw_path = raw_path[len(prefix):]
44+
is_high_noise = True
45+
46+
lora_data.append({
47+
'name': raw_path,
48+
'multiplier': mul,
49+
'is_high_noise': is_high_noise,
50+
})
51+
52+
prompt = prompt.replace(match.group(0), "", 1)
53+
54+
return prompt, lora_data
55+
2556

2657
def parse_arguments():
2758
ap = argparse.ArgumentParser(
@@ -137,8 +168,13 @@ def parse_arguments():
137168
if not args_dict.get("server_url", "").strip():
138169
ap.error("--server-url not provided and SD_SERVER_URL env var not found")
139170

140-
if not args_dict.get("prompt", "").strip():
171+
prompt = args_dict.get("prompt", "").strip()
172+
prompt, lora = extract_lora_tags(prompt)
173+
if not prompt:
141174
ap.error("argument -p/--prompt must be non‑empty")
175+
args_dict["prompt"] = prompt
176+
if lora:
177+
args_dict["lora"] = lora
142178

143179
util_keys = {'verbose', 'server_url', 'output', 'output_begin_idx', 'api', 'init_img', 'mask', 'ref_image', 'output_format'}
144180

@@ -352,7 +388,10 @@ def truncate_for_debug(obj, max_length=512):
352388
def do_request(url, data, headers):
353389

354390
if HAS_REQUESTS:
355-
response = requests.post(url, headers=headers, data=data, timeout=30)
391+
response = requests.post(url, headers=headers, data=data, timeout=600)
392+
if response.status_code != 200:
393+
print(f"HTTP {response.status_code}: {response.reason}")
394+
print(f" {response.text}")
356395
response.raise_for_status()
357396
return response.text
358397

@@ -396,6 +435,27 @@ def main_sdapi(util_opts, gen_opts, verbose=False):
396435
endpoint = urllib.parse.urljoin(server_url, "sdapi/v1/txt2img")
397436

398437
api_parameters = build_sdapi_parameters(gen_opts, util_opts, image_opts)
438+
lora = gen_opts.get('lora')
439+
if lora:
440+
# TODO: refactor, error handling, is_high_noise
441+
lora_list = json.loads(requests.get(urllib.parse.urljoin(server_url, "sdapi/v1/loras")).text)
442+
#print(f"remote lora list: {json.dumps(lora_list, indent=2)}")
443+
lora_map = {l.get("name"): l.get("path") for l in lora_list}
444+
req_lora = []
445+
print(f"requesting LoRAs")
446+
for llora in lora:
447+
name = llora['name']
448+
rlora = lora_map.get(name)
449+
if rlora:
450+
entry = {'path': rlora, 'multiplier': llora["multiplier"]}
451+
if verbose:
452+
print(f" lora '{name}' mapped to '{rlora}'")
453+
req_lora.append(entry)
454+
else:
455+
if verbose:
456+
print(f" warning: lora '{name}' not found on remote")
457+
if req_lora:
458+
api_parameters["lora"] = req_lora
399459

400460
if verbose:
401461
print(f"Using sdapi")

0 commit comments

Comments
 (0)