diff options
Diffstat (limited to 'stable_diffusion.cpp')
-rw-r--r-- | stable_diffusion.cpp | 248 |
1 files changed, 241 insertions, 7 deletions
diff --git a/stable_diffusion.cpp b/stable_diffusion.cpp index 0c42f03..2eb230b 100644 --- a/stable_diffusion.cpp +++ b/stable_diffusion.cpp @@ -1,15 +1,121 @@ +struct curl_data { + char *response; + size_t size; +}; + +static size_t dumbcurlcallback(void *data, size_t size, size_t nmemb, void *userp) + { + size_t realsize = size * nmemb; + curl_data *mem = (curl_data *)userp; + + Memory_Copy((uint8 *)&(mem->response[mem->size]), (uint8 *)data, realsize); + mem->size += realsize; + mem->response[mem->size] = 0; + + return realsize; + } + +static void +SD_JSONToSource(project_data *File, project_state *State, memory *Memory, void *JSONResponse, int Height, int Width) +{ + // We've gotta go from base64 to compressed PNG to the raw format we want. + uint8 *Data = (uint8 *)JSONResponse; + uint64 i = 0; + Assert(Data[2] != 'd'); + while (Data[i] != '[') + i++; + i += 2; + Assert(Data[i] != '"'); + uint64 UncompressedSize = Width * Height * 4; + uint64 c = 0; + while (Data[i+c] != ']') + c++; + uint64 MaxSize = Width * Height * 4; + void *PNGData = Memory_PushScratch(Memory, MaxSize); + uint64 PNGSize = 0; + base64_decode(&Data[i], c, (uint8 *)PNGData, &PNGSize); + int x, y, a; + void *RawData = stbi_load_from_memory((stbi_uc *)PNGData, PNGSize, &x, &y, &a, 4); + Assert(x == Width && y == Height); + Memory_PopScratch(Memory, MaxSize); + int SrcIdx = Source_Generate_Blank(File, State, Memory, Width, Height, 4); + block_source *Source = (block_source *)Memory_Block_AddressAtIndex(Memory, F_Sources, SrcIdx); + // Source->Type = source_type_principal_temp; + void *BitmapAddress = Memory_Block_AddressAtIndex(Memory, F_PrincipalBitmaps, Source->Bitmap_Index, 0); + Memory_Copy((uint8 *)BitmapAddress, (uint8 *)RawData, MaxSize); + stbi_image_free(RawData); +} static void -SD_Txt2Txt(sd_state *SD) +SD_ParseProgress(project_state *State, char *JSONInfo) { - char JSONPayload[1024]; - char CurlCommand[1024]; + char P[4]; + char P2[8]; + int ProgLocation = 0; + while (!(JSONInfo[ProgLocation] == 's' && JSONInfo[ProgLocation+1] == 's')) { + ProgLocation++; + } + ProgLocation += 4; + int ETALocation = ProgLocation; + while (!(JSONInfo[ETALocation] == 'v' && JSONInfo[ETALocation+1] == 'e')) { + ETALocation++; + } + ETALocation += 4; + // Assert(JSONInfo[ProgLocation] >= '0' && JSONInfo[ProgLocation] <= '9'); + // Assert(JSONInfo[ETALocation] >= '0' && JSONInfo[ETALocation] <= '9'); + Memory_Copy((uint8 *)P, (uint8 *)&JSONInfo[12], 4); + Memory_Copy((uint8 *)P2, (uint8 *)&JSONInfo[32], 8); + real32 Percent = atof(P); + if (Percent > 0.0f) // occasionally returns negative + State->SDPercentDone = Percent; + real32 Time = atof(P2); + if (Time > 0.1f) // occasionally returns zero for some reason + State->SDTimeEstimate = Time; + // Assert(0); +} + +static void +JSON_AppendParam_String(char *String, uint64 *i, char *P1, char *P2) +{ + uint64 c = *i; + String[c++] = '"'; + int a = 0; + while(P1[a] != '\0') { + String[c++] = P1[a++]; + } + String[c++] = ':'; + String[c++] = ' '; + String[c++] = '"'; + a = 0; + while(P2[a] != '\0') { + String[c++] = P2[a++]; + if (a > 64) + break; + } + c--; + String[c++] = '"'; + String[c++] = ','; + String[c++] = '\n'; + String[c++] = '\0'; + *i = c; +} + +static void +SD_AssembleJSON(sd_state *SD, char *JSONPayload, void *Base64Bitmap = NULL) +{ + Arbitrary_Zero((uint8 *)JSONPayload, 1024); + // char CurlCommand[1024]; char *Test[] = { "prompt", "negative_prompt", "steps", "width", "height", "cfg_scale" }; void *Test2[6] = { (void *)SD->Prompt, (void *)SD->NegPrompt, (void *)&SD->Steps, (void *)&SD->Width, (void *)&SD->Height, (void *)&SD->CFG }; int Type[6] = { 0, 0, 1, 1, 1, 2}; - sprintf(JSONPayload, "%s{\n", JSONPayload); + JSONPayload[0] = '{'; + JSONPayload[1] = '\n'; + JSONPayload[2] = '\0'; + uint64 i = 2; + if (SD->Mode) + JSON_AppendParam_String(JSONPayload, &i, "init_images", (char *)Base64Bitmap); for (int i = 0; i < 6; i++) { if (Type[i] == 0) { sprintf(JSONPayload, "%s\"%s\": \"%s\",\n", JSONPayload, Test[i], (char *)Test2[i]); @@ -21,8 +127,136 @@ SD_Txt2Txt(sd_state *SD) Assert(0); } } + if (SD->Mode) + sprintf(JSONPayload, "%s\"%s\": %.2f,\n", JSONPayload, "denoising_strength", SD->DenoisingStrength); + sprintf(JSONPayload, "%s%s\n", JSONPayload, "\"sampler_index\": \"DDIM\""); sprintf(JSONPayload, "%s}\n", JSONPayload); - sprintf(CurlCommand, "curl -X POST -H 'Content-Type: application/json' -i '%s/sdapi/v1/txt2img' --data '%s'", - SD->ServerAddress, JSONPayload); - printf("%s\n", CurlCommand); + // sprintf(CurlCommand, "curl -X POST -H 'Content-Type: application/json' -i '%s/sdapi/v1/txt2img' --data '%s'", SD->ServerAddress, JSONPayload); + // printf("%s\n", CurlCommand); }; + +struct curl_state +{ + CURL *curl; + CURLM *curlm; + curl_slist *list = NULL; + curl_data CurlData; +}; + +static void +Curl_Free(curl_state *Handle) +{ + curl_multi_remove_handle(Handle->curlm, Handle->curl); + curl_easy_cleanup(Handle->curl); + curl_multi_cleanup(Handle->curlm); +} + +static bool32 +Curl_Check(curl_state *Handle) +{ + int IsActive; + CURLMcode mc = curl_multi_perform(Handle->curlm, &IsActive); + Assert(!mc); + if (!IsActive) { + int queue = 0; + CURLMsg *msg = curl_multi_info_read(Handle->curlm, &queue); + if (msg) { + CURL *e = msg->easy_handle; + Assert(e == Handle->curl); + Assert(msg->msg == CURLMSG_DONE); + if (!msg->data.result) { + return 1; + } else if (msg->data.result == CURLE_COULDNT_CONNECT) { + // printf("Active stable-diffusion-webui instance not found at URL.\n"); + return -1; + } else { + // printf("curl error: %s!\n", curl_easy_strerror(msg->data.result)); + return -1; + } + } + } + return 0; +} + +static char *APIString[] = {"txt2img", "img2img" }; + +static void +Curl_GET_Init(curl_state *C, void *OutputData, char *JSONPayload, char *IP, bool32 API) +{ + C->list = curl_slist_append(C->list, "Content-Type: application/json"); + + C->curl = curl_easy_init(); + Assert(C->curl); + + char URL[512]; + sprintf(URL, "%s/sdapi/v1/%s", IP, APIString[API]); + curl_easy_setopt(C->curl, CURLOPT_URL, URL); + curl_easy_setopt(C->curl, CURLOPT_HTTPHEADER, C->list); + curl_easy_setopt(C->curl, CURLOPT_POSTFIELDS, JSONPayload); + + C->CurlData = { (char *)OutputData, 0 }; + + curl_easy_setopt(C->curl, CURLOPT_WRITEFUNCTION, dumbcurlcallback); + curl_easy_setopt(C->curl, CURLOPT_WRITEDATA, (void *)&C->CurlData); + + C->curlm = curl_multi_init(); + Assert(C->curlm); + curl_multi_add_handle(C->curlm, C->curl); + + Curl_Check(C); +} + +static void +Curl_Prog_Init(curl_state *C, void *OutputData) +{ + C->curl = curl_easy_init(); + Assert(C->curl); + + C->CurlData = { (char *)OutputData, 0 }; + + curl_easy_setopt(C->curl, CURLOPT_URL, "http://127.0.0.1:7860/sdapi/v1/progress"); + curl_easy_setopt(C->curl, CURLOPT_WRITEFUNCTION, dumbcurlcallback); + curl_easy_setopt(C->curl, CURLOPT_WRITEDATA, (void *)&C->CurlData); + + C->curlm = curl_multi_init(); + Assert(C->curlm); + curl_multi_add_handle(C->curlm, C->curl); +} + + /* + int num; + CURLMcode mc = curl_multi_perform(C->curlm, &num); + Assert(!mc); + + int queue = 0; + CURLMsg *msg = curl_multi_info_read(C->curlm, &queue); + if (msg) { + CURL *e = msg->easy_handle; + if (msg->msg == CURLMSG_DONE) { + if (msg->data.result == CURLE_COULDNT_CONNECT) { + printf("Active stable-diffusion-webui instance not found at URL.\n"); + } else { + printf("curl error: %s!\n", curl_easy_strerror(msg->data.result)); + } + Assert(0); + } + } + */ + + +#if 0 + res = curl_easy_perform(curl); + /* Check for errors */ + if(res != CURLE_OK) + fprintf(stderr, "curl_easy_perform() failed: %s\n", + curl_easy_strerror(res)); + + /* Perform the request, res will get the return code */ + + /* always cleanup */ + curl_slist_free_all(list); + curl_easy_cleanup(curl); + + + curl_global_cleanup(); +#endif |