summaryrefslogtreecommitdiff
path: root/src/stable_diffusion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/stable_diffusion.cpp')
-rw-r--r--src/stable_diffusion.cpp287
1 files changed, 287 insertions, 0 deletions
diff --git a/src/stable_diffusion.cpp b/src/stable_diffusion.cpp
new file mode 100644
index 0000000..4da327d
--- /dev/null
+++ b/src/stable_diffusion.cpp
@@ -0,0 +1,287 @@
+struct curl_data {
+ char *response;
+ size_t size;
+};
+
+static size_t dumbcurlcallback(void *data, size_t size, size_t nmemb, void *userp)
+ {
+ size_t realsize = size * nmemb;
+ curl_data *mem = (curl_data *)userp;
+
+ Memory_Copy((uint8 *)&(mem->response[mem->size]), (uint8 *)data, realsize);
+ mem->size += realsize;
+ mem->response[mem->size] = 0;
+
+ return realsize;
+ }
+
+static void
+SD_JSONToSource(project_data *File, project_state *State, memory *Memory, void *JSONResponse, int Height, int Width)
+{
+ // We've gotta go from base64 to compressed PNG to the raw format we want.
+ uint8 *Data = (uint8 *)JSONResponse;
+ uint64 i = 0;
+ Assert(Data[2] != 'd');
+ while (Data[i] != '[')
+ i++;
+ i += 2;
+ Assert(Data[i] != '"');
+ uint64 UncompressedSize = Width * Height * 4;
+ uint64 c = 0;
+ while (Data[i+c] != ']')
+ c++;
+ uint64 MaxSize = Width * Height * 4;
+ void *PNGData = Memory_PushScratch(Memory, MaxSize);
+ uint64 PNGSize = 0;
+ base64_decode(&Data[i], c, (uint8 *)PNGData, (size_t *)&PNGSize);
+ int x, y, a;
+ void *RawData = stbi_load_from_memory((stbi_uc *)PNGData, PNGSize, &x, &y, &a, 4);
+ Assert(x == Width && y == Height);
+ Memory_PopScratch(Memory, MaxSize);
+ int32 Highest = 0;
+ {
+ int h = 0, c = 0, i = 0;
+ while (Block_Loop(Memory, F_Sources, File->Source_Count, &h, &c, &i)) {
+ block_source *Source = (block_source *)Memory_Block_AddressAtIndex(Memory, F_Sources, i);
+ if (Source->Type == source_type_principal_temp && Source->RelativeTimestamp > Highest)
+ Highest = Source->RelativeTimestamp;
+ }
+ }
+ int SrcIdx = Source_Generate_Blank(File, State, Memory, Width, Height, 4);
+ block_source *Source = (block_source *)Memory_Block_AddressAtIndex(Memory, F_Sources, SrcIdx);
+ Source->Type = source_type_principal_temp;
+ Source->RelativeTimestamp = Highest + 1;
+ void *BitmapAddress = Memory_Block_AddressAtIndex(Memory, F_PrincipalBitmaps, Source->Bitmap_Index, 0);
+ Memory_Copy((uint8 *)BitmapAddress, (uint8 *)RawData, MaxSize);
+ stbi_image_free(RawData);
+}
+
+static void
+SD_ParseProgress(project_state *State, char *JSONInfo)
+{
+ char P[4];
+ char P2[8];
+ int ProgLocation = 0;
+ while (!(JSONInfo[ProgLocation] == 's' && JSONInfo[ProgLocation+1] == 's')) {
+ ProgLocation++;
+ }
+ ProgLocation += 4;
+ int ETALocation = ProgLocation;
+ while (!(JSONInfo[ETALocation] == 'v' && JSONInfo[ETALocation+1] == 'e')) {
+ ETALocation++;
+ }
+ ETALocation += 4;
+ // Assert(JSONInfo[ProgLocation] >= '0' && JSONInfo[ProgLocation] <= '9');
+ // Assert(JSONInfo[ETALocation] >= '0' && JSONInfo[ETALocation] <= '9');
+ Memory_Copy((uint8 *)P, (uint8 *)&JSONInfo[12], 4);
+ Memory_Copy((uint8 *)P2, (uint8 *)&JSONInfo[32], 8);
+ real32 Percent = atof(P);
+ if (Percent > 0.0f) // occasionally returns negative
+ State->SDPercentDone = Percent;
+ real32 Time = atof(P2);
+ if (Time > 0.1f) // occasionally returns zero for some reason
+ State->SDTimeEstimate = Time;
+ // Assert(0);
+}
+
+static char *pre = "data:image/png;base64,";
+
+static void
+JSON_AppendParam_String(char *String, uint64 *i, char *P1, char *P2)
+{
+ uint64 c = *i;
+ String[c++] = '"';
+ int a = 0;
+ while(P1[a] != '\0') {
+ String[c++] = P1[a++];
+ }
+ String[c++] = '"';
+ String[c++] = ':';
+ String[c++] = ' ';
+ String[c++] = '[';
+ String[c++] = '"';
+ a = 0;
+ while(pre[a] != '\0') {
+ String[c++] = pre[a++];
+ }
+ a = 0;
+ while(P2[a] != '\0') {
+ String[c++] = P2[a++];
+ }
+ c--;
+ String[c++] = '"';
+ String[c++] = ']';
+ String[c++] = ',';
+ String[c++] = '\n';
+ String[c++] = '\0';
+ *i = c;
+}
+
+static void
+SD_AssembleJSON(sd_state *SD, char *JSONPayload, void *Base64Bitmap = NULL)
+{
+ Arbitrary_Zero((uint8 *)JSONPayload, 1024);
+ // char CurlCommand[1024];
+ char *Test[] = { "prompt", "negative_prompt", "steps", "width", "height", "cfg_scale" };
+ void *Test2[6] = { (void *)SD->Prompt, (void *)SD->NegPrompt,
+ (void *)&SD->Steps, (void *)&SD->Width,
+ (void *)&SD->Height, (void *)&SD->CFG };
+ int Type[6] = { 0, 0, 1, 1, 1, 2};
+ JSONPayload[0] = '{';
+ JSONPayload[1] = '\n';
+ JSONPayload[2] = '\0';
+ uint64 i = 2;
+ if (SD->Mode) {
+ JSON_AppendParam_String(JSONPayload, &i, "init_images", (char *)Base64Bitmap);
+ }
+ for (int i = 0; i < 6; i++) {
+ if (Type[i] == 0) {
+ sprintf(JSONPayload, "%s\"%s\": \"%s\",\n", JSONPayload, Test[i], (char *)Test2[i]);
+ } else if (Type[i] == 1) {
+ sprintf(JSONPayload, "%s\"%s\": %i,\n", JSONPayload, Test[i], *(int *)Test2[i]);
+ } else if (Type[i] == 2) {
+ sprintf(JSONPayload, "%s\"%s\": %.2f,\n", JSONPayload, Test[i], *(real32 *)Test2[i]);
+ } else {
+ Assert(0);
+ }
+ }
+ if (SD->Mode)
+ sprintf(JSONPayload, "%s\"%s\": %.2f,\n", JSONPayload, "denoising_strength", SD->DenoisingStrength);
+ sprintf(JSONPayload, "%s%s\n", JSONPayload, "\"sampler_index\": \"DPM++ 2S a Karras\"");
+ sprintf(JSONPayload, "%s}\n", JSONPayload);
+ printf("%s\n", JSONPayload);
+ // sprintf(CurlCommand, "curl -X POST -H 'Content-Type: application/json' -i '%s/sdapi/v1/txt2img' --data '%s'", SD->ServerAddress, JSONPayload);
+ // printf("%s\n", CurlCommand);
+};
+
+struct curl_state
+{
+ CURL *curl;
+ CURLM *curlm;
+ curl_slist *list = NULL;
+ curl_data CurlData;
+};
+
+static void
+Curl_Free(curl_state *Handle)
+{
+ curl_multi_remove_handle(Handle->curlm, Handle->curl);
+ curl_easy_cleanup(Handle->curl);
+ curl_multi_cleanup(Handle->curlm);
+}
+
+static void
+Curl_StopAll(project_state *State, curl_state *ProgHandle, curl_state *MainHandle)
+{
+ Curl_Free(ProgHandle);
+ curl_slist_free_all(ProgHandle->list);
+ Curl_Free(MainHandle);
+ State->CurlActive = 0;
+}
+
+static int
+Curl_Check(curl_state *Handle)
+{
+ int IsActive;
+ CURLMcode mc = curl_multi_perform(Handle->curlm, &IsActive);
+ Assert(!mc);
+ if (!IsActive) {
+ int queue = 0;
+ CURLMsg *msg = curl_multi_info_read(Handle->curlm, &queue);
+ if (msg) {
+ CURL *e = msg->easy_handle;
+ Assert(e == Handle->curl);
+ Assert(msg->msg == CURLMSG_DONE);
+ if (!msg->data.result) {
+ return 1;
+ } else if (msg->data.result == CURLE_COULDNT_CONNECT) {
+ return -1;
+ } else {
+ printf("curl error: %s!\n", curl_easy_strerror(msg->data.result));
+ return -2;
+ }
+ }
+ }
+ return 0;
+}
+
+static char *APIString[] = {"txt2img", "img2img" };
+
+static void
+Curl_GET_Init(curl_state *C, void *OutputData, char *JSONPayload, char *IP, bool32 API)
+{
+ C->list = curl_slist_append(C->list, "Content-Type: application/json");
+
+ C->curl = curl_easy_init();
+ Assert(C->curl);
+
+ char URL[512];
+ sprintf(URL, "%s/sdapi/v1/%s", IP, APIString[API]);
+ curl_easy_setopt(C->curl, CURLOPT_URL, URL);
+ curl_easy_setopt(C->curl, CURLOPT_HTTPHEADER, C->list);
+ curl_easy_setopt(C->curl, CURLOPT_POSTFIELDS, JSONPayload);
+
+ C->CurlData = { (char *)OutputData, 0 };
+
+ curl_easy_setopt(C->curl, CURLOPT_WRITEFUNCTION, dumbcurlcallback);
+ curl_easy_setopt(C->curl, CURLOPT_WRITEDATA, (void *)&C->CurlData);
+
+ C->curlm = curl_multi_init();
+ Assert(C->curlm);
+ curl_multi_add_handle(C->curlm, C->curl);
+
+ Curl_Check(C);
+}
+
+static void
+Curl_Prog_Init(curl_state *C, void *OutputData)
+{
+ C->curl = curl_easy_init();
+ Assert(C->curl);
+
+ C->CurlData = { (char *)OutputData, 0 };
+
+ curl_easy_setopt(C->curl, CURLOPT_URL, "http://127.0.0.1:7860/sdapi/v1/progress");
+ curl_easy_setopt(C->curl, CURLOPT_WRITEFUNCTION, dumbcurlcallback);
+ curl_easy_setopt(C->curl, CURLOPT_WRITEDATA, (void *)&C->CurlData);
+
+ C->curlm = curl_multi_init();
+ Assert(C->curlm);
+ curl_multi_add_handle(C->curlm, C->curl);
+}
+
+static void
+Curl_Main(project_data *File, project_state *State, memory *Memory, curl_state *MainHandle, curl_state *ProgHandle)
+{
+ if (State->CurlActive == -1) {
+ Curl_GET_Init(MainHandle, State->Dump1, State->JSONPayload, File->UI.SD.ServerAddress, File->UI.SD.Mode);
+ Curl_Prog_Init(ProgHandle, State->Dump2);
+ State->CurlActive = 1;
+ } else {
+ if (Curl_Check(MainHandle) == 1) {
+ SD_JSONToSource(File, State, Memory, State->Dump1, File->UI.SD.Height, File->UI.SD.Width);
+ Curl_StopAll(State, ProgHandle, MainHandle);
+ }
+ uint64 Time = ImGui::GetTime();
+ if (Time - State->SDTimer > 0.3f) {
+ int Test = Curl_Check(ProgHandle);
+ if (Test == 1) {
+ SD_ParseProgress(State, (char *)State->Dump2);
+ curl_multi_remove_handle(ProgHandle->curlm, ProgHandle->curl);
+ curl_easy_reset(ProgHandle->curl);
+ ProgHandle->CurlData.size = 0;
+ curl_easy_setopt(ProgHandle->curl, CURLOPT_URL, "http://127.0.0.1:7860/sdapi/v1/progress");
+ curl_easy_setopt(ProgHandle->curl, CURLOPT_WRITEFUNCTION, dumbcurlcallback);
+ curl_easy_setopt(ProgHandle->curl, CURLOPT_WRITEDATA, (void *)&ProgHandle->CurlData);
+ curl_multi_add_handle(ProgHandle->curlm, ProgHandle->curl);
+ } else if (Test == -1) {
+ PostMsg(State, "Active stable-diffusion-webui instance not found at URL.");
+ Curl_StopAll(State, ProgHandle, MainHandle);
+ } else if (Test == -2) {
+ PostMsg(State, "CURL error; see command line.");
+ Curl_StopAll(State, ProgHandle, MainHandle);
+ }
+ State->SDTimer = Time;
+ }
+ }
+}